In [1]:
import torch
import torch.nn as nn
from torchvision.datasets import CIFAR10
from tqdm.auto import tqdm

from models.model import ResNet18
from models.transforms import transforms_resnet
from models.utils import train_step, eval_step, DataLoaders

In [3]:
# CIFAR-10 train/test splits, stored under ./data (downloaded on first run).
# NOTE(review): both splits use the same transforms_resnet pipeline; if it contains
# random augmentation, the test split is augmented too — confirm that is intended.
train = CIFAR10(root='data', train=True, download=True, transform=transforms_resnet)
test = CIFAR10(root='data', train=False, download=True, transform=transforms_resnet)

Files already downloaded and verified
Files already downloaded and verified


In [8]:
# Training configuration — everything a reader might want to tune lives here.
BATCH_SIZE = 64  # samples per mini-batch
N_EPOCHS = 30    # full passes over the training set
SEED = 42        # seed for torch's global RNG

# Seed torch (CPU and all CUDA devices) before any model or loader is created, so
# weight initialisation and any torch-driven shuffling/augmentation are repeatable.
# NOTE(review): bit-exact GPU reruns additionally need deterministic cuDNN settings.
torch.manual_seed(SEED)

In [9]:
# Wrap the train/test datasets in the project's DataLoaders helper.
# NOTE(review): the positional arguments are opaque from here — 'resnet' presumably
# selects the model family, True presumably enables shuffling, and 'cifar10' names the
# dataset. Confirm against models.utils.DataLoaders and prefer keyword arguments.
dl = DataLoaders(train,test,'resnet',BATCH_SIZE,True,'cifar10')

In [10]:
# Build the iterable loaders used by the training loop below.
# NOTE(review): assumes get_loaders() returns (train_loader, test_loader) in that order — confirm.
train_loader, test_loader = dl.get_loaders()

In [11]:
# Train-then-evaluate loop: one training pass and one evaluation pass per epoch,
# recording per-epoch accuracy and loss for both splits.
def train_and_eval(train_loader, test_loader, model, loss_fn, optimizer, device, modeltype,
                   n_epochs=None, data="cifar10"):
    """Train ``model`` for several epochs, evaluating on the test split after each one.

    Args:
        train_loader: loader forwarded to ``train_step`` every epoch.
        test_loader: loader forwarded to ``eval_step`` every epoch.
        model: the network being trained and evaluated.
        loss_fn: loss criterion forwarded to both steps.
        optimizer: optimizer forwarded to ``train_step``.
        device: device string forwarded to both steps (e.g. ``'cuda'`` or ``'cpu'``).
        modeltype: model identifier forwarded to both steps (e.g. ``'resnet18'``).
        n_epochs: number of epochs to run; ``None`` (default) uses the notebook-level
            ``N_EPOCHS`` constant, looked up at call time.
        data: dataset name forwarded to ``eval_step`` (default ``"cifar10"``).

    Returns:
        ``(tr_metric, ts_metric)``: two dicts with keys ``"Accuracy"`` and ``"Loss"``,
        each mapping to a list holding one float per epoch.
    """
    if n_epochs is None:
        # Resolved at call time so later edits to the config cell are honoured.
        n_epochs = N_EPOCHS

    tr_metric = {"Accuracy": [], "Loss": []}
    ts_metric = {"Accuracy": [], "Loss": []}

    for _ in tqdm(range(n_epochs), desc="epochs"):
        tr_loss, tr_acc = train_step(model, train_loader, loss_fn, optimizer, device, modeltype)
        ts_loss, ts_acc = eval_step(model, test_loader, loss_fn, device, modeltype, data=data)

        # float() normalises numpy / 0-d tensor scalars, so the histories hold plain
        # Python numbers (safe for pandas, JSON, and plotting).
        tr_metric["Accuracy"].append(float(tr_acc))
        tr_metric["Loss"].append(float(tr_loss))
        ts_metric["Accuracy"].append(float(ts_acc))
        ts_metric["Loss"].append(float(ts_loss))

    return tr_metric, ts_metric

# Pick the accelerator first, so the model is placed on it before the optimizer is
# built over its parameters.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Move the model explicitly so correctness does not depend on train_step/eval_step
# doing it. Assumes ResNet18 is an nn.Module (it exposes parameters()/state_dict()),
# whose .to() moves parameters in place and returns the module itself.
model = ResNet18(n_classes=10).to(device)
loss_func = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters())  # default Adam hyper-parameters (lr=1e-3)

tr,ts = train_and_eval(train_loader,test_loader,model,loss_func,optimizer,device,'resnet18')

# Report the metrics from the last epoch.
print("Final Train Accuracy:",tr["Accuracy"][-1])
print("Final Test Accuracy:",ts["Accuracy"][-1])
print("Final Train Loss:",tr["Loss"][-1])
print("Final Test Loss:",ts["Loss"][-1])

100%|██████████| 30/30 [26:19<00:00, 52.66s/it]

Final Train Accuracy: 0.7905
Final Test Accuracy: 0.78209996
Final Train Loss: 0.6016889525496442
Final Test Loss: 0.6397212705794414





In [12]:
import pandas as pd

# Persist the per-epoch metric histories (one row per epoch; columns "Accuracy" and "Loss").
pd.DataFrame(data=tr).to_csv(path_or_buf="resnet18_cifar10_tr_1.csv")
pd.DataFrame(data=ts).to_csv(path_or_buf="resnet18_cifar10_ts_1.csv")

In [14]:
# Persist the trained weights, then reload them as a round-trip sanity check
# (load_state_dict is strict by default, so any key mismatch raises).
CHECKPOINT_PATH = "resnet18_cifar10.pth"
torch.save(model.state_dict(), CHECKPOINT_PATH)

# torch.load is pickle-based. weights_only=True restricts unpickling to tensors and
# plain containers. map_location="cpu" lets the file load on machines without the
# training GPU; load_state_dict then copies the values onto the model's own device.
model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location="cpu", weights_only=True))