In [1]:
import torch
import torchvision
import numpy as np
from torchvision import transforms
import matplotlib.pyplot as plt
import torchvision.models as models
from tqdm import tqdm
import torch.nn as nn
import torch.nn.functional as F
import cv2

In [2]:
SEED = 1

# Seed the RNGs this notebook relies on so Restart & Run All is reproducible:
# torch drives weight initialisation and DataLoader shuffling; numpy is seeded too
# because it is used for index handling below.
torch.manual_seed(SEED)
np.random.seed(SEED)

# Prefer CUDA, then Apple-silicon MPS, and fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu')
print('Current Device:', device)

Current Device: cpu


### Load Dataset

In [3]:
# Per-channel (R, G, B) ImageNet statistics used for input normalisation.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Preprocessing: resize every image to exactly 256x256 (the (h, w) tuple means the
# aspect ratio is not preserved), convert it to a tensor, then normalise per channel.
# NOTE(review): torchvision's reference ResNet-18 eval pipeline also center-crops to
# 224x224; no crop is applied here. Add transforms.CenterCrop(224) after Resize only
# if every model fed by this pipeline accepts 224x224 inputs.
IMAGENET_TRANSFORMS = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

In [4]:
# Load both Imagenette splits from the local copy (download disabled), applying the
# ImageNet preprocessing pipeline to every image. The train split is built first.
data_train, data_val = (
    torchvision.datasets.Imagenette(
        root='./datasets/imagenette',
        split=split_name,
        download=False,
        transform=IMAGENET_TRANSFORMS,
    )
    for split_name in ('train', 'val')
)

In [5]:
class ImagenetteDataset(torch.utils.data.Dataset):
    """Subset of an Imagenette-style dataset relabelled for one-vs-rest classification.

    Each item is ``(image, label)`` where ``label`` is 1 when the wrapped dataset's
    label equals ``positive_class`` and 0 otherwise.

    Args:
        dataset: indexable dataset whose items are ``(image, label)`` pairs.
        indices: positions in ``dataset`` to expose, in order.
        transform: optional callable applied to each image after it is fetched.
            NOTE(review): if ``dataset`` already applies its own transform, this
            one runs on top of it.
        positive_class: label index mapped to 1. Defaults to 3, which is
            'chain saw' in Imagenette's ``class_to_idx``.
    """

    def __init__(self, dataset, indices, transform=None, positive_class=3):
        self.transforms = transform
        self.positive_class = positive_class
        self.data = torch.utils.data.Subset(dataset, indices)

    def __getitem__(self, index):
        image, label = self.data[index]
        if self.transforms is not None:
            image = self.transforms(image)

        # Collapse the multi-class label to binary: positive class -> 1, rest -> 0.
        label = 1 if label == self.positive_class else 0

        return image, label

    def __len__(self):
        return len(self.data)

In [6]:
# Show the class-name -> index mapping. 'chain saw' / 'chainsaw' map to index 3,
# the label ImagenetteDataset turns into the positive (1) class.
data_train.class_to_idx

{'tench': 0,
 'Tinca tinca': 0,
 'English springer': 1,
 'English springer spaniel': 1,
 'cassette player': 2,
 'chain saw': 3,
 'chainsaw': 3,
 'church': 4,
 'church building': 4,
 'French horn': 5,
 'horn': 5,
 'garbage truck': 6,
 'dustcart': 6,
 'gas pump': 7,
 'gasoline pump': 7,
 'petrol pump': 7,
 'island dispenser': 7,
 'golf ball': 8,
 'parachute': 9,
 'chute': 9}

In [7]:
# Class rebalancing is disabled for now because this notebook does not train:
# every sample of each split is kept.
# TODO: to rebalance, keep only samples whose label is 1, 3 or 4 (English springer,
# chain saw, church). The earlier draft found them by enumerating each dataset,
# which loads and transforms every image just to read its label.
train_indices, val_indices = (
    np.arange(len(split_data)) for split_data in (data_train, data_val)
)

In [8]:
# Wrap each split so that labels become binary (chain saw vs. everything else).
data_train_chainsaw = ImagenetteDataset(dataset=data_train, indices=train_indices)
data_val_chainsaw = ImagenetteDataset(dataset=data_val, indices=val_indices)

In [9]:
# Mini-batches of 64; only the training split is shuffled, validation order stays fixed.
train_loader = torch.utils.data.DataLoader(
    dataset=data_train_chainsaw,
    batch_size=64,
    shuffle=True,
)
val_loader = torch.utils.data.DataLoader(
    dataset=data_val_chainsaw,
    batch_size=64,
    shuffle=False,
)

### Set Up, Train and Evaluate Model

In [10]:
from models.model import Resnet18, CNNClassifier

# Alternative model for the binary task: CNNClassifier(num_classes=2).
model = Resnet18()

# NOTE(review): ImagenetteDataset yields binary labels (0/1); confirm that
# Resnet18()'s output size is what CrossEntropyLoss should see for this task.
criterion = nn.CrossEntropyLoss()

lr = 0.008
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

In [11]:
def train(model, data_loader, criterion, optimizer):
    """Run one training epoch and return its running metrics.

    Moves ``model`` to the module-level ``device``, switches it to training mode,
    and takes one optimizer step per batch yielded by ``data_loader``.

    Args:
        model: the ``nn.Module`` to train (updated in place).
        data_loader: iterable of ``(inputs, targets)`` batches.
        criterion: loss function called as ``criterion(outputs, targets)``.
        optimizer: optimizer over ``model``'s parameters.

    Returns:
        ``(average_loss, accuracy)``: the mean of the per-batch loss values, and the
        percentage of samples whose arg-max prediction equals the target.

    Raises:
        ValueError: if ``data_loader`` yields no batches.
    """
    model.train()
    model.to(device)

    running_loss = 0.
    running_steps = 0
    n_correct = 0
    n_total = 0

    with tqdm(data_loader, desc="   train") as train_tqdm:
        for inputs, targets in train_tqdm:
            inputs, targets = inputs.to(device), targets.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            running_steps += 1

            # Predicted class = index of the largest output along dim 1.
            pred = outputs.argmax(dim=1)
            n_correct += (pred == targets).sum().item()
            n_total += len(targets)

            # Show the running (cumulative) metrics rather than per-batch values.
            train_tqdm.set_postfix(loss=running_loss / running_steps,
                                   accuracy=n_correct / n_total * 100.)

    # Without this guard an empty loader would leave the metrics undefined.
    if running_steps == 0:
        raise ValueError('train() received an empty data_loader; there are no batches to train on.')

    return running_loss / running_steps, n_correct / n_total * 100.

In [12]:
@torch.no_grad()
def test(model, data_loader, criterion, optimizer):
    model.eval()
    model.to(device)

    running_loss = 0.
    running_steps = 0
    n_correct = 0
    n_total = 0

    with tqdm(data_loader, desc ="   test") as test_tqdm:
        for inputs, targets in test_tqdm:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)

            loss = criterion(outputs, targets)
            running_loss += loss.item()
            running_steps += 1

            _, pred = torch.max(outputs, 1)
            n_correct += (pred==targets).sum().item()
            n_total += len(targets)
    
        average_test_loss = running_loss / running_steps
        accuracy = n_correct / n_total * 100.

        test_tqdm.set_postfix(loss=average_test_loss, accuracy=accuracy)

    return average_test_loss, accuracy

In [13]:
# Training is switched off for now: with zero epochs the loop body never runs.
NUM_EPOCHS = 0
# Validate and print a summary every EVAL_EVERY epochs.
EVAL_EVERY = 10

for epoch in range(1, NUM_EPOCHS + 1):
    train_avg_loss, train_accuracy = train(model, train_loader, criterion, optimizer)
    if epoch % EVAL_EVERY == 0:
        val_avg_loss, val_accuracy = test(model, val_loader, criterion, optimizer)
        print('Epoch: {:3d}, Train Average Loss: {:.2f}, Train Accuracy: {:.1f}%, Validation Average Loss: {:.2f}, Validation Accuracy: {:.1f}%'
              .format(epoch, train_avg_loss, train_accuracy, val_avg_loss, val_accuracy))

In [14]:
# Evaluate the current weights on the validation split and report the averages.
val_avg_loss, val_accuracy = test(model, val_loader, criterion, optimizer)
print(
    'Validation Average Loss: {:.2f}, Validation Accuracy: {:.1f}%'.format(
        val_avg_loss, val_accuracy
    )
)

### Save Model Parameters

In [15]:
# Persist the model's state_dict (not the pickled module object).
# NOTE(review): while the training loop above runs zero epochs, these are the
# weights as constructed by Resnet18(), not trained weights.
torch.save(obj=model.state_dict(), f='./models/default_imagenet.pth')