In [2]:
from torch.utils.data import Dataset, DataLoader
from torch.nn import Linear, CrossEntropyLoss
from torchvision import transforms, models
import torch.nn.functional as F
import matplotlib.pyplot as plt
import torch.optim as optim
import torch.nn as nn
from tqdm import tqdm
import torchvision
import pathlib
import pickle
import torch
import os
import gc

In [7]:
model

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 

In [10]:
class ResnetWrapper(nn.Module):
    def __init__(self):
        super(ResnetWrapper, self).__init__()
        self.resnet = models.resnet152(pretrained=True)
        self.fc = Linear(2048, 99, bias=True)
        self.set_for_finetune()

    def set_for_finetune(self):
        nn.init.xavier_normal_(self.fc.weight)

        self.resnet.conv1.requires_grad_= False
        self.resnet.bn1.requires_grad_ = False
        self.resnet.layer1.requires_grad_ = False
        self.resnet.layer2.requires_grad_ = False
        self.resnet.layer3.requires_grad_ = False
        self.resnet.layer4.requires_grad_ = False
        self.resnet.fc = self.fc

    def forward(self, x):
        return self.resnet(x)


In [None]:
class BevDataset(Dataset):
  def __init__(self, root, size=224, split='train'):
    self.split = split
    postfix = split
    root = os.path.join(root, 'bev_classification', 'images')
    self.dataset_folder = torchvision.datasets.ImageFolder(os.path.join(root, postfix) ,transform = transforms.Compose([transforms.Resize((size,size)),transforms.ToTensor()]))

  def __getitem__(self,index):
    img = self.dataset_folder[index]
    path = self.dataset_folder.imgs[index]
    return img[0], img[1], path[0]

  def __len__(self):
    return len(self.dataset_folder)

In [None]:
# Hyper parameters
# With batch_size 50, there will be 1776 iterations over the dataset per epoch
batch_size = 5
num_epochs = 5
lr = 1e-4

epoch: 1 batch: 2025 accuracy: 22.64% val accuracy 39.20% loss: 3.2424 val loss: 2.3560:   4%|▍         | 2026/48580 [05:31<1:10:22, 11.03it/s]

In [None]:
train_losses = [0]
train_accuracy = [0]

val_losses = [0]
val_accuracy = [0]

model_path = 'model/mps-resnet-model.pkl'

In [None]:
def fine_tune():
    try:
        gc.collect()
        device = torch.device('mps')
        if not pathlib.Path(model_path).exists():
            model = ResnetWrapper().to(device)
        else:
            print('Model Found!')
            print('Loading model...')
            with open(model_path, 'rb') as f:
                model = pickle.load(f).to(device)

        print()
        train_loader = DataLoader(BevDataset('.'), batch_size=batch_size, shuffle=True)
        val_loader = DataLoader(BevDataset('.', split='val'), batch_size=batch_size, shuffle=True)

        # Only have 10 validation checks per epoch
        val_check = len(train_loader) // 5

        criterion = CrossEntropyLoss()
        optimizer = optim.Adam(model.parameters(), lr=lr)

        loop = tqdm(total=len(train_loader)*num_epochs, position=0)

        for epoch in range(num_epochs):
            train_step_losses = []
            train_step_accuracy = []
            for batch, (x, y_truth, _) in enumerate(train_loader):
                x, y_truth = x.to(device), y_truth.to(device)

                optimizer.zero_grad()

                y_hat = model(x)

                accuracy = (y_hat.argmax(1) == y_truth).float().mean()
                train_step_accuracy.append(accuracy.item())

                loss = criterion(y_hat, y_truth)
                train_step_losses.append(loss.item())

                loss.backward()
                optimizer.step()

                if (batch + 1) % val_check == 0:
                    print('Validation check')
                    val_loop = tqdm(total=len(val_loader), position=0)

                    val_batch_loss = []
                    val_batch_accuracy = []
                    for batch, (x, y_truth, _) in enumerate(val_loader):
                        x, y_truth = x.to(device), y_truth.to(device)

                        optimizer.zero_grad()

                        y_hat = model(x)

                        accuracy = (F.softmax(y_hat,1).argmax(1) == y_truth).float().mean()
                        val_batch_accuracy.append(accuracy.item())

                        loss = criterion(y_hat, y_truth)
                        val_batch_loss.append(loss.item())

                        loss.backward()
                        optimizer.step()
                        val_loop.update(1)
                        val_loop.set_description(f'val batch: {batch} val accuracy: {accuracy*100:.2f}% val loss: {loss:.4f}')

                    val_losses.append(sum(val_batch_loss) / len(val_batch_loss))
                    val_accuracy.append(sum(val_batch_accuracy) / len(val_batch_accuracy))

                    train_losses.append(sum(train_step_losses) / len(train_step_losses))
                    train_accuracy.append(sum(train_step_accuracy) / len(train_step_accuracy))

                loop.update(1)
                loop.set_description(f'epoch: {epoch+1} batch: {batch} accuracy: {train_accuracy[-1]*100:.2f}% val accuracy {val_accuracy[-1]*100:.2f}% loss: {train_losses[-1]:.4f} val loss: {val_losses[-1]:.4f}')

            print('Saving model...')
            with open(model_path, 'wb') as f:
                pickle.dump(model, f)
            print('Model saved.')
    except KeyboardInterrupt:
        print('Saving model...')
        with open(model_path, 'wb') as f:
            pickle.dump(model, f)
        print('Model saved.')

# fine_tune()

Model Found!
Loading model...



epoch: 5 batch: 3886 accuracy: 64.60% val accuracy 71.68% loss: 1.2531 val loss: 1.0059: 100%|██████████| 19435/19435 [2:12:10<00:00,  2.45it/s]
epoch: 1 batch: 775 accuracy: 64.60% val accuracy 71.68% loss: 1.2531 val loss: 1.0059:   4%|▍         | 776/19435 [02:47<1:04:23,  4.83it/s]

Validation check


val batch: 553 val accuracy: 100.00% val loss: 0.2774: 100%|██████████| 554/554 [15:34<00:00,  1.69s/it]
epoch: 1 batch: 1552 accuracy: 65.42% val accuracy 72.22% loss: 1.2173 val loss: 0.9801:   8%|▊         | 1553/19435 [07:36<1:01:56,  4.81it/s]

Validation check


val batch: 553 val accuracy: 100.00% val loss: 0.9236: 100%|██████████| 554/554 [04:49<00:00,  1.91it/s]
epoch: 1 batch: 2329 accuracy: 65.51% val accuracy 71.97% loss: 1.2149 val loss: 0.9837:  12%|█▏        | 2330/19435 [12:24<58:53,  4.84it/s]

Validation check


val batch: 553 val accuracy: 100.00% val loss: 0.6252: 100%|██████████| 554/554 [04:47<00:00,  1.93it/s]
epoch: 1 batch: 3106 accuracy: 65.50% val accuracy 72.14% loss: 1.2184 val loss: 0.9725:  16%|█▌        | 3107/19435 [17:11<1:03:39,  4.27it/s]

Validation check


val batch: 553 val accuracy: 100.00% val loss: 0.0323: 100%|██████████| 554/554 [04:47<00:00,  1.92it/s]
epoch: 1 batch: 3883 accuracy: 65.54% val accuracy 72.27% loss: 1.2204 val loss: 0.9670:  20%|█▉        | 3884/19435 [21:59<53:36,  4.84it/s]

Validation check


val batch: 553 val accuracy: 100.00% val loss: 0.1130: 100%|██████████| 554/554 [04:47<00:00,  1.92it/s]
epoch: 1 batch: 3886 accuracy: 65.37% val accuracy 73.25% loss: 1.2228 val loss: 0.9494:  20%|██        | 3887/19435 [23:59<109:23:14, 25.33s/it]

Saving model...
Model saved.


epoch: 2 batch: 775 accuracy: 65.37% val accuracy 73.25% loss: 1.2228 val loss: 0.9494:  24%|██▍       | 4663/19435 [26:54<54:48,  4.49it/s]

Validation check


val batch: 553 val accuracy: 100.00% val loss: 0.2532: 100%|██████████| 554/554 [04:55<00:00,  1.88it/s]
epoch: 2 batch: 1552 accuracy: 66.29% val accuracy 73.19% loss: 1.1852 val loss: 0.9507:  28%|██▊       | 5440/19435 [31:44<49:10,  4.74it/s]

Validation check


val batch: 553 val accuracy: 100.00% val loss: 0.5377: 100%|██████████| 554/554 [04:49<00:00,  1.91it/s]
epoch: 2 batch: 2329 accuracy: 66.13% val accuracy 72.99% loss: 1.1884 val loss: 0.9416:  32%|███▏      | 6217/19435 [36:32<49:51,  4.42it/s]

Validation check


val batch: 553 val accuracy: 100.00% val loss: 0.1898: 100%|██████████| 554/554 [04:48<00:00,  1.92it/s]
epoch: 2 batch: 3106 accuracy: 65.74% val accuracy 72.93% loss: 1.2002 val loss: 0.9352:  36%|███▌      | 6994/19435 [41:18<43:41,  4.75it/s]

Validation check


val batch: 553 val accuracy: 100.00% val loss: 0.0033: 100%|██████████| 554/554 [04:46<00:00,  1.93it/s]
epoch: 2 batch: 3856 accuracy: 65.64% val accuracy 73.49% loss: 1.2013 val loss: 0.9265:  40%|███▉      | 7744/19435 [45:58<44:29,  4.38it/s]

In [17]:
def among_us():
    imposter = input("Imposter: ")
    amonguses = []
    while True:
        crewmate = input('Crewmate: ')
        if crewmate == '':
            break
        amonguses.append(crewmate)
    print('What happened? ')
    if imposter == 'Your Mother':
        print('Imposter wins bozo')
    else:
        for item in amonguses:
            print(f"{item} clowned on the imposter and ate part of {imposter}...")

In [18]:
among_us()

What happened? 
Imposter wins bozo
