In [1]:
from torch.nn import Conv2d, MaxPool2d, Dropout, Linear, ReLU, CrossEntropyLoss
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import torch.nn.functional as F
import matplotlib.pyplot as plt
import torch.optim as optim
import torch.nn as nn
from tqdm import tqdm
import torchvision
import pathlib
import pickle
import torch
import os
import gc

In [2]:
class BevDataset(Dataset):
  def __init__(self, root, size=224, split='train'):
    self.split = split
    postfix = split
    root = os.path.join(root, 'bev_classification', 'images')
    self.dataset_folder = torchvision.datasets.ImageFolder(os.path.join(root, postfix) ,transform = transforms.Compose([transforms.Resize((size,size)),transforms.ToTensor()]))

  def __getitem__(self,index):
    img = self.dataset_folder[index]
    path = self.dataset_folder.imgs[index]
    return img[0], img[1], path[0]
  
  def __len__(self):
    return len(self.dataset_folder)
  
# class BevTestDataset(Dataset):
#   def __init__(self, root, size=224, split='train'):
#     postfix = split
#     root = os.path.join(root, 'bev_classification', 'images')
#     self.dataset_folder = torchvision.datasets.ImageFolder(os.path.join(root, postfix) ,transform = transforms.Compose([transforms.Resize((size,size)),transforms.ToTensor()]))

#   def __getitem__(self,index):
#     img = self.dataset_folder[index]
#     path = self.dataset_folder.imgs[index]
#     return img[0], img[1], path[0]
  
#   def __len__(self):
#     return len(self.dataset_folder)

In [3]:
class ImageClassifier(nn.Module):
    def __init__(self, dropout=0.2):
        super(ImageClassifier, self).__init__()
        output = 99
        self.dropout = Dropout(dropout)
        self.fc1 = Linear(512*32*32, output)
        self.fc2 = Linear(output, output)
        self.conv1 = Conv2d(3, 64, (3,3), padding=(1,1))
        self.conv2 = Conv2d(64, 128, (3,3), padding=(1,1))
        self.conv3 = Conv2d(128, 256, (3,3), padding=(1,1))
        self.conv4 = Conv2d(256, 512, (3,3), padding=(1,1))
        
        self.net = nn.Sequential(
            # Image size = 512 x 512 x 3
            self.conv1, 
            ReLU(),
            MaxPool2d(kernel_size=2, stride=2),

            # Image size = 256 x 256 x 64
            self.conv2, 
            ReLU(),
            Dropout(dropout),
            MaxPool2d(kernel_size=2, stride=2),

            # Image size = 128 x 128 x 128
            self.conv3, 
            ReLU(),
            MaxPool2d(kernel_size=2, stride=2),

            # Image size = 64 x 64 x 256
            self.conv4,
            ReLU(),
            Dropout(dropout),
            MaxPool2d(kernel_size=2, stride=2)
            # Image size = 32 x 32 x 512
        )

        self.initialize_weights()


    def initialize_weights(self):
        gain = 2**(1/2)
        nn.init.xavier_normal_(self.fc1.weight, gain=gain)
        nn.init.xavier_normal_(self.fc2.weight, gain=gain)
        nn.init.xavier_normal_(self.conv1.weight, gain=gain)
        nn.init.xavier_normal_(self.conv2.weight, gain=gain)
        nn.init.xavier_normal_(self.conv3.weight, gain=gain)
        nn.init.xavier_normal_(self.conv4.weight, gain=gain)


    def forward(self, X):
        output = self.net(X)
        _, c, h, w = output.size()
        output = output.view(-1, c*h*w)
        output = self.dropout(self.fc1(output))
        return self.fc2(output)

In [4]:
# Hyper parameters
# With batch_size 50, there will be 1776 iterations over the dataset per epoch
num_chunks = 10
batch_size = 5
num_epochs = 2
lr = 5e-3

In [5]:
train_losses = []
train_accuracy = []

val_losses = [0]
val_accuracy = [0]

model_path = 'model/mps-model.pkl'

if not pathlib.Path('model').exists():
    pathlib.Path('model').mkdir()

def train():
    try:
        gc.collect()
        device = torch.device('mps')
        if not pathlib.Path(model_path).exists():
            model = ImageClassifier().to(device)
        else:
            print('Model Found!')
            print('Loading model...')
            with open(model_path, 'rb') as f:
                model = pickle.load(f).to(device)

        # TODO: make sure to split the data into 10 samples and train on each 
        for i in range(num_chunks):
            print()
            train_loader = DataLoader(BevDataset('.', chunk=i), batch_size=batch_size, shuffle=True)
            val_loader = DataLoader(BevDataset('.', split='val', chunk=i), batch_size=batch_size, shuffle=True)
            
            # Only have 10 validation checks per epoch
            val_check = len(train_loader) // 5

            criterion = CrossEntropyLoss()
            optimizer = optim.Adam(model.parameters(), lr=lr)

            loop = tqdm(total=len(train_loader)*num_epochs, position=0)

            for epoch in range(num_epochs):
                train_step_losses = []
                train_step_accuracy = []
                for batch, (x, y_truth) in enumerate(train_loader):
                    x, y_truth = x.to(device), y_truth.to(device)

                    optimizer.zero_grad()

                    y_hat = model(x)

                    accuracy = (y_hat.argmax(1) == y_truth).float().mean()
                    train_accuracy.append((batch*epoch, accuracy.item()))

                    loss = criterion(y_hat, y_truth)
                    train_losses.append(loss.item())

                    loss.backward()
                    optimizer.step()

                    if (batch + 1) % val_check == 0:
                        print('Validation check')
                        val_loop = tqdm(total=len(val_loader), position=0)

                        val_batch_loss = []
                        val_batch_accuracy = []
                        for batch, (x, y_truth) in enumerate(val_loader):
                            x, y_truth = x.to(device), y_truth.to(device)

                            optimizer.zero_grad()

                            y_hat = model(x)

                            accuracy = (F.softmax(y_hat,1).argmax(1) == y_truth).float().mean()
                            val_batch_accuracy.append((batch*epoch, accuracy.item()))

                            loss = criterion(y_hat, y_truth)
                            val_batch_loss.append(loss.item())

                            loss.backward()
                            optimizer.step()
                            val_loop.update(1)
                            val_loop.set_description(f'val batch: {batch} val accuracy: {accuracy*100:.2f}% val loss: {loss:.4f}')

                        val_losses.append(sum(val_batch_loss) / len(val_batch_loss))
                        val_accuracy.append(sum(val_batch_accuracy) / len(val_batch_accuracy))

                    loop.update(1)
                    loop.set_description(f'epoch: {epoch+1} batch: {batch} accuracy: {accuracy*100:.2f}% val accuracy {val_accuracy[-1]*100:.2f}% loss: {loss:.4f} val loss: {val_losses[-1]:.4f}')
                
                print('Saving model...')
                with open(model_path, 'wb') as f:
                    pickle.dump(model, f)
                print('Model saved.')
    except KeyboardInterrupt:
        print('Saving model...')
        with open(model_path, 'wb') as f:
            pickle.dump(model, f)
        print('Model saved.')


# train()


In [6]:
import torchvision.models as models

class VGGIntermediate(nn.Module):
    def __init__(self):
        super(VGGIntermediate, self).__init__()
        self.vgg = models.vgg19(pretrained=True)
        self.set_up_vgg()

    def set_up_vgg(self):
        for param in self.vgg.parameters():
            param.requires_grad = False

        num_features = self.vgg.classifier[-1].in_features  # Get the number of input features for the final layer
        self.vgg.classifier[-1] = Linear(num_features, 99)
        # Optionally, you may want to initialize the new layer weights
        # Initialize weights with Xavier initialization
        torch.nn.init.xavier_uniform_(self.vgg.classifier[-1].weight)
        # Optionally, initialize biases to zeros
        torch.nn.init.zeros_(self.vgg.classifier[-1].bias)

    def forward(self, X):
        return self.vgg(X)

In [11]:
# Hyper parameters
# With batch_size 50, there will be 1776 iterations over the dataset per epoch
batch_size = 8
num_epochs = 2
lr = 5e-4

In [13]:
train_losses = [0]
train_accuracy = [0]

val_losses = [0]
val_accuracy = [0]

model_path = 'model/mps-model.pkl'

In [14]:

try:
    gc.collect()
    device = torch.device('mps')
    if not pathlib.Path(model_path).exists():
        model = VGGIntermediate().to(device)
    else:
        print('Model Found!')
        print('Loading model...')
        with open(model_path, 'rb') as f:
            model = pickle.load(f).to(device)

    for i in range(num_chunks):
        print()
        train_loader = DataLoader(BevDataset('.'), batch_size=batch_size, shuffle=True)
        val_loader = DataLoader(BevDataset('.', split='val'), batch_size=batch_size, shuffle=True)
        
        # Only have 10 validation checks per epoch
        val_check = len(train_loader) // 5

        criterion = CrossEntropyLoss()
        optimizer = optim.Adam(model.parameters(), lr=lr)

        loop = tqdm(total=len(train_loader)*num_epochs, position=0)

        for epoch in range(num_epochs):
            train_step_losses = []
            train_step_accuracy = []
            for batch, (x, y_truth, _) in enumerate(train_loader):
                x, y_truth = x.to(device), y_truth.to(device)

                optimizer.zero_grad()

                y_hat = model(x)

                accuracy = (y_hat.argmax(1) == y_truth).float().mean()
                train_step_accuracy.append(accuracy.item())

                loss = criterion(y_hat, y_truth)
                train_step_losses.append(loss.item())

                loss.backward()
                optimizer.step()

                if (batch + 1) % val_check == 0:
                    print('Validation check')
                    val_loop = tqdm(total=len(val_loader), position=0)

                    val_batch_loss = []
                    val_batch_accuracy = []
                    for batch, (x, y_truth, _) in enumerate(val_loader):
                        x, y_truth = x.to(device), y_truth.to(device)

                        optimizer.zero_grad()

                        y_hat = model(x)

                        accuracy = (F.softmax(y_hat,1).argmax(1) == y_truth).float().mean()
                        val_batch_accuracy.append(accuracy.item())

                        loss = criterion(y_hat, y_truth)
                        val_batch_loss.append(loss.item())

                        loss.backward()
                        optimizer.step()
                        val_loop.update(1)
                        val_loop.set_description(f'val batch: {batch} val accuracy: {accuracy*100:.2f}% val loss: {loss:.4f}')

                    val_losses.append(sum(val_batch_loss) / len(val_batch_loss))
                    val_accuracy.append(sum(val_batch_accuracy) / len(val_batch_accuracy))

                    train_losses.append(sum(train_step_losses) / len(train_step_losses))
                    train_accuracy.append(sum(train_step_accuracy) / len(train_step_accuracy))

                loop.update(1)
                loop.set_description(f'epoch: {epoch+1} batch: {batch} accuracy: {train_accuracy[-1]*100:.2f}% val accuracy {val_accuracy[-1]*100:.2f}% loss: {train_losses[-1]:.4f} val loss: {val_losses[-1]:.4f}')
            
            print('Saving model...')
            with open(model_path, 'wb') as f:
                pickle.dump(model, f)
            print('Model saved.')
except KeyboardInterrupt:
    print('Saving model...')
    with open(model_path, 'wb') as f:
        pickle.dump(model, f)
    print('Model saved.')

Model Found!
Loading model...



epoch: 2 batch: 2204 accuracy: 0.00% val accuracy 53.15% loss: 0.0000 val loss: 8.8426:  60%|█████▉    | 13309/22208 [1:08:12<45:36,  3.25it/s]
epoch: 1 batch: 1941 accuracy: 0.00% val accuracy 0.00% loss: 0.0000 val loss: 0.0000:  10%|▉         | 1942/19432 [07:03<1:04:13,  4.54it/s]

Validation check


val batch: 1580 val accuracy: 100.00% val loss: -0.0000: 100%|██████████| 1581/1581 [19:52<00:00,  1.33it/s]
epoch: 1 batch: 3884 accuracy: 54.72% val accuracy 59.51% loss: 8.5756 val loss: 6.6549:  20%|█▉        | 3885/19432 [19:10<57:37,  4.50it/s]    

Validation check


val batch: 1382 val accuracy: 40.00% val loss: 6.8572: 100%|██████████| 1383/1383 [12:07<00:00,  1.90it/s]
epoch: 1 batch: 5827 accuracy: 55.36% val accuracy 60.45% loss: 8.3701 val loss: 6.3343:  30%|██▉       | 5828/19432 [32:17<59:15,  3.83it/s]    

Validation check


val batch: 1382 val accuracy: 80.00% val loss: 4.2790: 100%|██████████| 1383/1383 [13:06<00:00,  1.76it/s]
epoch: 1 batch: 7770 accuracy: 55.76% val accuracy 60.97% loss: 8.1723 val loss: 5.9848:  40%|███▉      | 7771/19432 [46:05<47:28,  4.09it/s]     

Validation check


val batch: 1382 val accuracy: 60.00% val loss: 6.7955: 100%|██████████| 1383/1383 [13:47<00:00,  1.67it/s]
epoch: 1 batch: 9713 accuracy: 56.05% val accuracy 62.45% loss: 8.0235 val loss: 5.6726:  50%|████▉     | 9714/19432 [59:32<39:22,  4.11it/s]     

Validation check


val batch: 1382 val accuracy: 80.00% val loss: 0.9140: 100%|██████████| 1383/1383 [13:26<00:00,  1.71it/s]
epoch: 1 batch: 9715 accuracy: 56.21% val accuracy 61.53% loss: 7.9182 val loss: 5.6270:  50%|█████     | 9716/19432 [1:05:07<190:15:20, 70.49s/it] 

Saving model...
Model saved.


epoch: 2 batch: 1941 accuracy: 56.21% val accuracy 61.53% loss: 7.9182 val loss: 5.6270:  60%|█████▉    | 11658/19432 [1:13:02<31:46,  4.08it/s]   

Validation check


val batch: 1382 val accuracy: 40.00% val loss: 5.7219: 100%|██████████| 1383/1383 [13:29<00:00,  1.71it/s]
epoch: 2 batch: 3884 accuracy: 57.72% val accuracy 62.52% loss: 7.0769 val loss: 5.3167:  70%|██████▉   | 13601/19432 [1:46:26<21:02,  4.62it/s]     

Validation check


val batch: 1382 val accuracy: 40.00% val loss: 9.6302: 100%|██████████| 1383/1383 [33:24<00:00,  1.45s/it]
epoch: 2 batch: 5827 accuracy: 57.95% val accuracy 63.44% loss: 6.9603 val loss: 5.0599:  80%|███████▉  | 15544/19432 [1:58:28<14:16,  4.54it/s]    

Validation check


val batch: 1382 val accuracy: 80.00% val loss: 0.6037: 100%|██████████| 1383/1383 [12:01<00:00,  1.92it/s]
epoch: 2 batch: 7770 accuracy: 57.96% val accuracy 64.46% loss: 6.9337 val loss: 4.8785:  90%|████████▉ | 17487/19432 [2:10:29<07:33,  4.29it/s]   

Validation check


val batch: 1382 val accuracy: 80.00% val loss: 1.2704: 100%|██████████| 1383/1383 [12:00<00:00,  1.92it/s]
epoch: 2 batch: 9713 accuracy: 57.97% val accuracy 64.16% loss: 6.9000 val loss: 4.8929: 100%|█████████▉| 19430/19432 [2:22:34<00:00,  4.75it/s]   

Validation check


val batch: 1382 val accuracy: 60.00% val loss: 4.2272: 100%|██████████| 1383/1383 [12:05<00:00,  1.91it/s]
epoch: 2 batch: 9715 accuracy: 58.06% val accuracy 63.89% loss: 6.8374 val loss: 4.7125: 100%|██████████| 19432/19432 [2:27:35<00:00, 63.17s/it]

Saving model...
Model saved.



epoch: 2 batch: 9715 accuracy: 58.06% val accuracy 63.89% loss: 6.8374 val loss: 4.7125: 100%|██████████| 19432/19432 [2:27:36<00:00,  2.19it/s]
epoch: 1 batch: 60 accuracy: 58.06% val accuracy 63.89% loss: 6.8374 val loss: 4.7125:   0%|          | 61/19432 [00:13<1:08:43,  4.70it/s]

Saving model...
Model saved.
