In [5]:
import torch
import torch.nn as nn
import numpy as np
import torch.optim as optim
from torchvision import transforms, datasets, models
from torch.utils.tensorboard import SummaryWriter
import time
import random
import numpy as np
from torchsummary import summary
from torch.utils.data import DataLoader

In [6]:
class AdaptiveConcatPool2d(nn.Module):
    """Concatenate adaptive max-pooling and adaptive average-pooling outputs.

    Both poolings are applied to the same input and the results are joined
    along dim 1 (the channel axis for NCHW input). The output therefore has
    twice the input's channel count: max-pooled features first, then the
    average-pooled ones.

    Args:
        sz: target spatial output size for both adaptive pooling layers.
            Any falsy value (including the default ``None``) means ``(1, 1)``.
    """

    def __init__(self, sz=None):
        super().__init__()
        output_size = (1, 1) if not sz else sz
        self.ap = nn.AdaptiveAvgPool2d(output_size)
        self.mp = nn.AdaptiveMaxPool2d(output_size)

    def forward(self, x):
        """Return ``cat([maxpool(x), avgpool(x)], dim=1)``."""
        pooled = (self.mp(x), self.ap(x))
        return torch.cat(pooled, dim=1)

In [None]:
def get_model(num_classes=2, pretrained=True):
    """Build a ResNet-50 transfer-learning classifier with a frozen backbone.

    Every parameter of the torchvision ResNet-50 is frozen. Its ``avgpool`` is
    then replaced by ``AdaptiveConcatPool2d`` and its ``fc`` by a fresh,
    trainable BatchNorm/Dropout/Linear head that ends in ``LogSoftmax``.

    Args:
        num_classes: number of output classes (default 2, the original value).
        pretrained: forwarded to ``models.resnet50``; load ImageNet weights
            for the backbone (default True, the original value).

    Returns:
        The modified ResNet-50 module. Because the head ends in
        ``LogSoftmax(dim=1)``, its outputs are log-probabilities of width
        ``num_classes``; pair it with ``nn.NLLLoss`` rather than
        ``nn.CrossEntropyLoss``.
    """
    model = models.resnet50(pretrained=pretrained)

    # Freeze the backbone. The head modules created below are new and
    # therefore trainable by default.
    for param in model.parameters():
        param.requires_grad = False

    # Concat pooling stacks max- and avg-pooled features, doubling the width
    # the original fc layer expected (2048 -> 4096 for ResNet-50). Derive it
    # from the backbone instead of hard-coding 4096.
    pooled_features = 2 * model.fc.in_features

    model.avgpool = AdaptiveConcatPool2d()
    model.fc = nn.Sequential(
        nn.Flatten(),  # no-op if the input is already (batch, features)
        nn.BatchNorm1d(pooled_features),
        nn.Dropout(0.5),
        nn.Linear(pooled_features, 512),
        nn.ReLU(),
        nn.BatchNorm1d(512),
        nn.Dropout(p=0.5),
        nn.Linear(512, num_classes),
        nn.LogSoftmax(dim=1)
    )
    return model

In [None]:
def train(model, device, train_loader, criterion, optimizer, epoch, writer):
    """Run one training epoch and log the mean batch loss to TensorBoard.

    Args:
        model: network to train; it is switched to train mode.
        device: device each ``(data, target)`` batch is moved to.
        train_loader: sized iterable of ``(data, target)`` batches.
        criterion: loss function, called as ``criterion(preds, target)``.
        optimizer: optimizer that is zeroed and stepped once per batch.
        epoch: global step used for the ``'Train/Loss'`` scalar.
        writer: SummaryWriter-like object providing ``add_scalar`` and ``flush``.

    Returns:
        The SUM of the per-batch losses over the epoch. Note that the logged
        scalar is the per-batch mean, unlike this return value. An empty
        loader returns 0 and logs 0.0 instead of raising ZeroDivisionError.
    """
    model.train()

    total_loss = 0

    # The batch index was never used, so iterate the loader directly.
    for data, target in train_loader:
        data, target = data.to(device), target.to(device)

        optimizer.zero_grad()
        preds = model(data)
        loss = criterion(preds, target)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    # max(..., 1) avoids ZeroDivisionError when the loader yields no batches.
    mean_loss = total_loss / max(len(train_loader), 1)
    writer.add_scalar('Train/Loss', mean_loss, epoch)
    writer.flush()

    return total_loss

In [8]:
def test(model, device, test_loader, criterion, epoch, writer):
    """Evaluate ``model`` on ``test_loader`` and log loss/accuracy to TensorBoard.

    Args:
        model: network to evaluate; it is switched to eval mode.
        device: device each ``(data, target)`` batch is moved to.
        test_loader: sized iterable of ``(data, target)`` batches that also
            exposes ``.dataset`` (used for the accuracy denominator).
        criterion: loss function, called as ``criterion(output, target)``.
        epoch: global step used for the ``'Test/Loss'`` and
            ``'Test/Accuracy'`` scalars.
        writer: SummaryWriter-like object providing ``add_scalar`` and
            ``flush``; also forwarded to ``misclassified_images``.

    Returns:
        ``(total_loss, accuracy)``:
            - ``total_loss`` is the mean of the per-batch criterion values.
            - ``accuracy`` is the percentage of correct predictions. It is a
              0-d CPU tensor when at least one batch is seen, otherwise 0.0.
    """
    model.eval()

    total_loss, correct = 0, 0

    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            total_loss += criterion(output, target).item()
            # Predicted class = index of the highest score along dim 1. Inside
            # no_grad the legacy `.data` attribute is unnecessary.
            pred = output.argmax(dim=1)
            correct += pred.eq(target).cpu().sum()

            # NOTE(review): this is called once per batch with the same
            # `epoch` step. If it logs to `writer`, it will do so repeatedly
            # within one epoch; confirm that is intended.
            misclassified_images(pred, writer, target, data, output, epoch)

    # Mean of per-batch losses. With a mean-reduction criterion, a smaller
    # final batch is weighted like a full one. max(..., 1) prevents
    # ZeroDivisionError on an empty loader or dataset.
    total_loss /= max(len(test_loader), 1)
    accuracy = 100. * correct / max(len(test_loader.dataset), 1)

    writer.add_scalar('Test/Loss', total_loss, epoch)
    writer.add_scalar('Test/Accuracy', accuracy, epoch)
    writer.flush()

    return total_loss, accuracy

In [None]:
def misclassified_images(pred, writer, target, data, output, epoch, count=10):
    