In [3]:
from torch.utils.data import Dataset, DataLoader
from torch.nn import Linear, CrossEntropyLoss
from torchvision import transforms, models
import torch.nn.functional as F
import matplotlib.pyplot as plt
import torch.optim as optim
import torch.nn as nn
from tqdm import tqdm
import torchvision
import pathlib
import pickle
import torch
import os
import gc

In [None]:
from google.colab import drive
drive.mount('/content/drive')

ModuleNotFoundError: No module named 'google.colab'

In [None]:
!unzip drive/MyDrive/hackathon_data.zip

In [None]:
import pathlib

images_path = 'bev_classification/images'
test_path = 'bev_classification/images/test'
train_path = 'bev_classification/images/train'

if not pathlib.Path(test_path).exists():
    pathlib.Path(test_path).mkdir()

if not pathlib.Path(train_path).exists():
    pathlib.Path(train_path).mkdir()

for file in pathlib.Path(images_path).iterdir():
    if file.name == '.DS_Store' or \
        file == test_path or \
        file == train_path:
        continue

    if 'test' in file.name:
        for data in pathlib.Path(file).iterdir():
            # Only iterate over the class directories
            if 'image-datasets' == data.name or 'input' in data.name or data.name == '.DS_Store': continue

            if pathlib.Path(f'{test_path}/{data.name}').exists():
                # print(data.name)
                for f in pathlib.Path(data).iterdir():
                    pathlib.Path(f).rename(f'{test_path}/{data.name}/{f.name}')

            else:
                pathlib.Path(data).rename(f'{test_path}/{data.name}')
            # print(f'{test_path}/{data.name}')
    else:
        for data in pathlib.Path(file).iterdir():
            # Only iterate over the class directories
            if 'image-datasets' == data.name or 'input' in data.name or data.name == '.DS_Store': continue

            if pathlib.Path(f'{train_path}/{data.name}').exists():

                for f in pathlib.Path(data).iterdir():
                    pathlib.Path(f).rename(f'{train_path}/{data.name}/{f.name}')

            else:
                pathlib.Path(data).rename(f'{train_path}/{data.name}')

In [None]:
import pathlib

train_path = 'bev_classification/images/train'
val_path = 'bev_classification/images/val'

if not pathlib.Path(val_path).exists():
    pathlib.Path(val_path).mkdir()

for id in pathlib.Path(train_path).iterdir():
    if not pathlib.Path(f"{val_path}/{id.name}").exists():
        pathlib.Path(f"{val_path}/{id.name}").mkdir()

    files = [file for file in pathlib.Path(id).iterdir()]
    val_files = files[:len(files)//8]

    [pathlib.Path(f).rename(f'{val_path}/{id.name}/{f.name}') for f in val_files]


In [19]:
class ResnetWrapper(nn.Module):
    def __init__(self):
        super(ResnetWrapper, self).__init__()
        self.resnet = models.resnet152(pretrained=True)
        self.fc = Linear(2048, 99, bias=True)
        self.set_for_finetune()

    def set_for_finetune(self):
        nn.init.xavier_normal_(self.fc.weight)

        for l in self.resnet.conv1.parameters():
            l.requires_grad = False
        for l in self.resnet.bn1.parameters():
            l.requires_grad = False
        for layer in [self.resnet.layer1, self.resnet.layer2, self.resnet.layer3, self.resnet.layer4]:
            for l in layer.parameters():
                l.requires_grad = False
        # self.resnet.layer1 #.requires_grad_ = False
        # self.resnet.layer2.requires_grad_ = False
        # self.resnet.layer3.requires_grad_ = False
        # self.resnet.layer4.requires_grad_ = False
        self.resnet.fc = self.fc

    def forward(self, x):
        return self.resnet(x)


In [20]:
class BevDataset(Dataset):
  def __init__(self, root, size=224, split='train'):
    self.split = split
    postfix = split
    root = os.path.join(root, 'bev_classification', 'images')
    self.dataset_folder = torchvision.datasets.ImageFolder(os.path.join(root, postfix) ,transform = transforms.Compose([transforms.Resize((size,size)),transforms.ToTensor()]))

  def __getitem__(self,index):
    img = self.dataset_folder[index]
    path = self.dataset_folder.imgs[index]
    return img[0], img[1], path[0]

  def __len__(self):
    return len(self.dataset_folder)

In [25]:
# Hyper parameters
# With batch_size 50, there will be 1776 iterations over the dataset per epoch
batch_size = 10
num_epochs = 5
lr = 1e-4

In [26]:
train_losses = [0]
train_accuracy = [0]

val_losses = [0]
val_accuracy = [0]

model_path = 'model/mps-resnet-model.pkl'

In [27]:
def fine_tune():
    try:
        gc.collect()
        device = torch.device('mps')
        if not pathlib.Path(model_path).exists():
            model = ResnetWrapper().to(device)
        else:
            print('Model Found!')
            print('Loading model...')
            with open(model_path, 'rb') as f:
                model = pickle.load(f).to(device)

        print()
        train_loader = DataLoader(BevDataset('.'), batch_size=batch_size, shuffle=True)
        val_loader = DataLoader(BevDataset('.', split='val'), batch_size=batch_size, shuffle=True)

        # Only have 10 validation checks per epoch
        val_check = len(train_loader) // 5

        criterion = CrossEntropyLoss()
        optimizer = optim.Adam(model.parameters(), lr=lr)

        loop = tqdm(total=len(train_loader)*num_epochs, position=0)

        for epoch in range(num_epochs):
            train_step_losses = []
            train_step_accuracy = []
            for batch, (x, y_truth, _) in enumerate(train_loader):
                x, y_truth = x.to(device), y_truth.to(device)

                optimizer.zero_grad()

                y_hat = model(x)

                accuracy = (y_hat.argmax(1) == y_truth).float().mean()
                train_step_accuracy.append(accuracy.item())

                loss = criterion(y_hat, y_truth)
                train_step_losses.append(loss.item())

                loss.backward()
                optimizer.step()

                if (batch + 1) % val_check == 0:
                    print('Validation check')
                    val_loop = tqdm(total=len(val_loader), position=0)

                    val_batch_loss = []
                    val_batch_accuracy = []
                    for batch, (x, y_truth, _) in enumerate(val_loader):
                        x, y_truth = x.to(device), y_truth.to(device)

                        optimizer.zero_grad()

                        y_hat = model(x)

                        accuracy = (F.softmax(y_hat,1).argmax(1) == y_truth).float().mean()
                        val_batch_accuracy.append(accuracy.item())

                        loss = criterion(y_hat, y_truth)
                        val_batch_loss.append(loss.item())

                        loss.backward()
                        optimizer.step()
                        val_loop.update(1)
                        val_loop.set_description(f'val batch: {batch} val accuracy: {accuracy*100:.2f}% val loss: {loss:.4f}')

                    val_losses.append(sum(val_batch_loss) / len(val_batch_loss))
                    val_accuracy.append(sum(val_batch_accuracy) / len(val_batch_accuracy))

                    train_losses.append(sum(train_step_losses) / len(train_step_losses))
                    train_accuracy.append(sum(train_step_accuracy) / len(train_step_accuracy))

                loop.update(1)
                loop.set_description(f'epoch: {epoch+1} batch: {batch} accuracy: {train_accuracy[-1]*100:.2f}% val accuracy {val_accuracy[-1]*100:.2f}% loss: {train_losses[-1]:.4f} val loss: {val_losses[-1]:.4f}')

            print('Saving model...')
            with open(model_path, 'wb') as f:
                pickle.dump(model, f)
            print('Model saved.')
    except KeyboardInterrupt:
        print('Saving model...')
        with open(model_path, 'wb') as f:
            pickle.dump(model, f)
        print('Model saved.')

fine_tune()

Model Found!
Loading model...



epoch: 1 batch: 1552 accuracy: 0.00% val accuracy 0.00% loss: 0.0000 val loss: 0.0000:   4%|▍         | 1553/38865 [08:27<3:16:27,  3.17it/s]

Validation check


epoch: 1 batch: 3106 accuracy: 29.44% val accuracy 57.25% loss: 3.3816 val loss: 2.1711:   8%|▊         | 3107/38865 [24:27<3:37:41,  2.74it/s]    

Validation check


val batch: 1106 val accuracy: 0.00% val loss: 4.3104: 100%|██████████| 1107/1107 [15:59<00:00,  1.15it/s]
epoch: 1 batch: 4660 accuracy: 47.57% val accuracy 70.34% loss: 2.5263 val loss: 1.3940:  12%|█▏        | 4661/38865 [40:50<3:36:29,  2.63it/s]    

Validation check


val batch: 1106 val accuracy: 0.00% val loss: 4.9071: 100%|██████████| 1107/1107 [16:23<00:00,  1.13it/s]
epoch: 1 batch: 6214 accuracy: 55.59% val accuracy 74.63% loss: 2.1010 val loss: 1.1023:  16%|█▌        | 6215/38865 [57:33<3:30:10,  2.59it/s]    

Validation check


val batch: 1106 val accuracy: 0.00% val loss: 4.5585: 100%|██████████| 1107/1107 [16:43<00:00,  1.10it/s]
epoch: 1 batch: 7768 accuracy: 60.27% val accuracy 77.06% loss: 1.8406 val loss: 0.9586:  20%|█▉        | 7769/38865 [1:14:59<3:30:07,  2.47it/s]    

Validation check


val batch: 1106 val accuracy: 0.00% val loss: 5.5779: 100%|██████████| 1107/1107 [17:25<00:00,  1.06it/s]
epoch: 1 batch: 7772 accuracy: 63.47% val accuracy 78.88% loss: 1.6634 val loss: 0.8565:  20%|██        | 7773/38865 [2:46:28<4882:38:15, 565.34s/it]  

Saving model...
Model saved.


epoch: 2 batch: 1552 accuracy: 63.47% val accuracy 78.88% loss: 1.6634 val loss: 0.8565:  24%|██▍       | 9326/38865 [3:04:58<2:49:34,  2.90it/s]    

Validation check


val batch: 1106 val accuracy: 0.00% val loss: 4.8080: 100%|██████████| 1107/1107 [1:49:59<00:00,  5.96s/it]
epoch: 2 batch: 3106 accuracy: 77.73% val accuracy 79.74% loss: 0.8672 val loss: 0.8079:  28%|██▊       | 10880/38865 [4:29:01<2:26:19,  3.19it/s]    

Validation check


val batch: 1106 val accuracy: 0.00% val loss: 5.1593: 100%|██████████| 1107/1107 [1:24:03<00:00,  4.56s/it]
epoch: 2 batch: 3135 accuracy: 78.11% val accuracy 80.24% loss: 0.8464 val loss: 0.7568:  28%|██▊       | 10909/38865 [6:11:59<5:19:27,  1.46it/s]      

Saving model...


epoch: 2 batch: 3135 accuracy: 78.11% val accuracy 80.24% loss: 0.8464 val loss: 0.7568:  28%|██▊       | 10909/38865 [6:12:00<15:53:18,  2.05s/it]
val batch: 1106 val accuracy: 0.00% val loss: 4.8710: 100%|██████████| 1107/1107 [1:42:57<00:00,  5.58s/it]

Model saved.





In [None]:
# Plot data points



In [9]:
def among_us():
    imposter = input("Imposter: ")
    amonguses = []
    while True:
        crewmate = input('Crewmate: ')
        if crewmate == '':
            break
        amonguses.append(crewmate)
    print('What happened? ')
    if imposter == 'Your Mother':
        print('Imposter wins bozo')
    else:
        for item in amonguses:
            print(f"{item} clowned on the imposter and ate part of {imposter}...")