In [1]:
from datasets import load_dataset
import torchvision.models as models
import torchvision
import torch.nn as nn
import torch
import os
import torchvision.transforms as transforms

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
random_seed = 42

In [2]:
def collate_fn(batch):
    transformations = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.4802, 0.4481, 0.3975], std=[0.2296, 0.2263, 0.2255])
    ])
    imgs = [transformations(item['image'].convert('RGB')) for item in batch]
    lbs = [item['label'] for item in batch]
    return (imgs,lbs)

In [11]:
def make_dataloader(batch_size=128, num_workers=7, train=True):
    if train:
        trainset = load_dataset('Maysee/tiny-imagenet', split='train')
        return torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, shuffle_seed=random_seed, collate_fn=collate_fn)
    
    else:
        evalset = load_dataset('Maysee/tiny-imagenet', split='valid')
        return torch.utils.data.DataLoader(evalset, batch_size=batch_size, shuffle=True, shuffle_seed=random_seed, collate_fn=collate_fn)

In [12]:
from ResNet50_pytorch import ResNet

In [13]:
#model = init_model()
model = ResNet()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

epochs = 5
loss_sum = 0
print_frequency = 7

In [14]:
train_loader = make_dataloader(train=True)
val_loader = make_dataloader(train=False)

In [15]:
model.train(True)
epoch_loss = 0
for epoch in range(epochs):
    _train_loader = train_loader.__iter__()
    
    itr = 0
    for i, (batch, target) in enumerate(_train_loader, start=itr):
        if(i > len(train_loader)-2):
            break

        batch = torch.stack(batch, dim=0)
        target = torch.tensor(target)

        optimizer.zero_grad()
        
        output = model(batch)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()  # optimization update
        loss_sum+=loss.item()
        epoch_loss += loss.item()
        
        
        if i % print_frequency == 0:
            print(epoch, i, loss_sum/print_frequency)
            loss_sum=0
            
    print("epoch_loss", epoch_loss)
    epoch_loss = 0

0 0 0.7612942968096051
0 7 2.65656045565681
0 14 3.498048614178385
0 21 3.4260009740080153
0 28 3.43661061248609
0 35 3.162638438067266
0 42 2.484109029173851
0 49 3.9342805487768993
0 56 3.7016399077006747
0 63 3.392662295273372
0 70 3.3453642640795027
0 77 2.9021721865449632
0 84 3.9020178339311054
0 91 3.6820112679685866
0 98 3.609568510736738
0 105 2.765452401978629
0 112 3.460080095699855
0 119 4.320494174957275
0 126 4.085704045636313
0 133 3.8572702748434886
0 140 3.7640933990478516
0 147 3.8655047586985996
0 154 4.513715539659772
0 161 4.148005247116089
0 168 4.214187008993966
0 175 3.774543251310076
0 182 4.3071174791881015
0 189 4.586383308683123
0 196 4.396456360816956
0 203 4.4124767780303955
0 210 3.7985069240842546
0 217 4.758507285799299
0 224 4.545607669012887
0 231 4.3895823274339945
0 238 4.257874148232596
0 245 4.4646658556801935
0 252 4.859663384301322
0 259 4.612051827566964
0 266 4.670415844236102
0 273 4.546453237533569
0 280 4.556282690593174
0 287 4.84717328207

KeyboardInterrupt: 