In [1]:
from torchvision.models import list_models,get_model
from torchvision.datasets import CIFAR100
import torch
import torchvision.transforms as transforms
from torch.utils.data import DataLoader,Dataset
import torch.nn as nn
import torch.optim as optim
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
import torch.nn.functional as F

from model import GoogLeNet

In [2]:
# Reference torchvision architectures, all built from scratch (no pretrained
# weights). NOTE(review): none of these three instances is referenced by the
# cells shown here — consider dropping them (resnet50 in particular is large).
googlenet, resnet, vgg = (
    get_model(arch_name, weights=None)
    for arch_name in ("googlenet", "resnet50", "vgg11")
)



In [3]:
class VGG(nn.Module):
    """torchvision VGG-11 backbone followed by a small MLP classification head.

    The backbone's 1000-dimensional output is passed through a 3-layer MLP
    that produces ``num_classes`` logits.

    Args:
        num_classes: size of the final output layer (default 100, matching
            the CIFAR-100 data used in this notebook).
        hidden_dim: width of the two hidden MLP layers (default 256).

    NOTE(review): the backbone already ends in a linear layer, and its output
    feeds straight into ``nn.Linear(1000, hidden_dim)`` with no non-linearity
    in between. This class is also not referenced by the cells shown here.
    """

    def __init__(self, num_classes=100, hidden_dim=256):
        super().__init__()
        # Explicit weights=None: train from scratch rather than relying on the
        # builder's default.
        self.vgg = get_model("vgg11", weights=None)
        self.linear = nn.Sequential(
            nn.Linear(1000, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, num_classes),
        )

    def forward(self, x):
        """Return ``num_classes`` logits for the image batch ``x``."""
        backbone_out = self.vgg(x)
        return self.linear(backbone_out)

In [4]:
# Shared normalization: maps ToTensor's [0, 1] range to [-1, 1]. It is applied
# to BOTH splits so evaluation inputs match what the model saw during training.
normalize_tfm = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))

# NOTE(review): CIFAR-100 images are 32x32, so RandomResizedCrop(224) with its
# default scale range upsamples small patches heavily — confirm this is intended.
train_tfm = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    normalize_tfm,
])

# Evaluation must be deterministic: use a plain resize instead of a random
# crop, and apply the same normalization as training.
test_tfm = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    normalize_tfm,
])

train_set = CIFAR100(root='./', train=True, download=True, transform=train_tfm)
test_set = CIFAR100(root='./', train=False, download=True, transform=test_tfm)

Files already downloaded and verified
Files already downloaded and verified


In [5]:
# Optimizer settings per architecture. Only `vgg_config` carries a "model"
# entry, so it is the only one the optimizer and training cells (which read
# config['model']) can consume as-is.
res_config = dict(lr=1e-4, weight_decay=1e-4)
google_config = dict(lr=1e-4, weight_decay=1e-4)

# NOTE(review): despite its name, this config trains the local GoogLeNet
# (with auxiliary classifiers enabled), not a VGG.
vgg_config = dict(
    model=GoogLeNet(num_classes=100, aux_logits=True, init_weights=True),
    lr=3e-4,
    weight_decay=0,
)

In [6]:
# Run-wide training settings.
batch_size = 64
num_epoch = 5

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# NOTE: `vgg_config` is the config whose "model" is the GoogLeNet instance.
config = vgg_config

In [7]:
def collate_fn(batch):
    """Stack a list of ``(image, label)`` samples into batched tensors.

    Returns:
        ``(images, target)``: the images stacked along a new leading dimension,
        and the labels as one-hot float32 rows over 100 classes.

    NOTE(review): neither DataLoader in the cells shown here passes this
    function, so training actually uses default collation (integer labels).
    """
    pixels, labels = zip(*batch)
    label_ids = torch.tensor(labels)
    one_hot = nn.functional.one_hot(label_ids, num_classes=100).to(torch.float32)
    return torch.stack(pixels), one_hot

In [8]:
# Shuffle only the training data. Evaluation metrics do not depend on order,
# and a fixed order keeps test passes reproducible.
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False)

# No collate_fn is passed, so targets arrive as integer class indices.
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(
    params=config['model'].parameters(),
    lr=config['lr'],
    weight_decay=config['weight_decay'],
)

In [9]:
def train_step(batch, model, loss_fn, optimizer, device=None):
    """Run one optimization step on a single batch.

    Args:
        batch: ``(images, target)`` pair. ``target`` may be integer class
            indices ``(B,)`` or one-hot / probability rows ``(B, C)``.
        model: module returning either plain logits, or a
            ``(logits, aux_logits2, aux_logits1)`` tuple (GoogLeNet with
            auxiliary classifiers, in training mode).
        loss_fn: criterion applied to every output head.
        optimizer: optimizer over ``model``'s parameters.
        device: where to move the batch. Defaults to the notebook-level
            ``device`` global, or ``"cpu"`` if that is not defined.

    Returns:
        ``(loss, acc)`` as Python floats. ``acc`` is the top-1 accuracy of
        the main head on this batch.
    """
    if device is None:
        device = globals().get("device", "cpu")
    images, target = batch
    images = images.to(device)
    target = target.to(device)

    optimizer.zero_grad()

    outputs = model(images)
    if isinstance(outputs, (tuple, list)):
        # Main head plus two auxiliary heads; the auxiliary losses are
        # down-weighted by 0.3.
        output, aux_logits2, aux_logits1 = outputs
        loss = (loss_fn(output, target)
                + loss_fn(aux_logits1, target) * 0.3
                + loss_fn(aux_logits2, target) * 0.3)
    else:
        output = outputs
        loss = loss_fn(output, target)

    loss.backward()
    optimizer.step()

    # Compare predicted indices against class indices even when targets
    # arrive one-hot encoded (e.g. from collate_fn).
    labels = target.argmax(dim=-1) if target.dim() > 1 else target
    acc = (output.argmax(dim=-1) == labels).float().mean()

    return loss.item(), acc.item()


def train_epoch(train_loader, model, loss_fn, optimizer, epoch):
    """Train ``model`` for one pass over ``train_loader``.

    A tqdm progress bar labelled with the 1-based epoch number wraps the
    loader. Every 5 steps its postfix is refreshed with the latest batch
    loss and accuracy.

    Returns:
        ``(losses, accuracies)``: per-batch values, in iteration order.
    """
    progress = tqdm(train_loader)
    progress.set_description(f'[Training Epoch: {epoch+1}]')
    step_losses, step_accs = [], []
    for step, batch in enumerate(progress, start=1):
        batch_loss, batch_acc = train_step(batch, model, loss_fn, optimizer)
        step_losses.append(batch_loss)
        step_accs.append(batch_acc)
        if step % 5 == 0:
            progress.set_postfix({'loss': f'{batch_loss:.5f}', 'acc': f'{batch_acc:.4f}'})
    return step_losses, step_accs

In [None]:
# Per-step training history across all epochs.
total_loss = []
total_acc = []
model = config['model']
model.train()
model.to(device)

for epoch in range(num_epoch):
    loss_recoder, acc_recoder = train_epoch(train_loader, model, loss_fn, optimizer, epoch)
    print(f"Epoch {epoch+1}: loss={sum(loss_recoder)/len(loss_recoder)} ,acc={sum(acc_recoder)/len(acc_recoder)}")
    total_loss += loss_recoder
    # Append this epoch's accuracies. Adding total_acc to itself would keep
    # the history empty forever.
    total_acc += acc_recoder

fig, (ax_loss, ax_acc) = plt.subplots(1, 2, figsize=(12, 4))
ax_loss.plot(total_loss, "b")
ax_loss.set(title="Training loss per step", xlabel="step", ylabel="loss")
ax_acc.plot(total_acc, "r")
ax_acc.set(title="Training accuracy per step", xlabel="step", ylabel="accuracy")
plt.show()

  0%|          | 0/782 [00:00<?, ?it/s]

Epoch 1: loss=6.629838353837543 ,acc=0.06613650895140664


  0%|          | 0/782 [00:00<?, ?it/s]

Epoch 2: loss=6.043308506231479 ,acc=0.11391064578005115


  0%|          | 0/782 [00:00<?, ?it/s]

Epoch 3: loss=5.784091513480067 ,acc=0.14426150895140666


  0%|          | 0/782 [00:00<?, ?it/s]