In [1]:
%load_ext autoreload
%autoreload 2

if 'google.colab' in str(get_ipython()):
    print('Running on CoLab')

    from google.colab import drive
    drive.mount('/content/drive', force_remount=False)
    # set time zone to Eastern Standard time
    !rm /etc/localtime
    !ln -s /usr/share/zoneinfo/US/Eastern /etc/localtime
    
    %cd drive/Shareddrives/ROB_535_Group_18/perception/

else:
    print('Not running on CoLab')
    %cd /home/benjamin/project

import torch
device = "cuda" if torch.cuda.is_available() else "cpu"

Not running on CoLab
/home/benjamin/project


Load relevant libraries

In [2]:
import argparse
import math
import os
import random
import sys
import time
import glob
import yaml

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.optim import SGD, Adam, lr_scheduler
from torchvision.ops import FeaturePyramidNetwork
from torchvision import transforms
from torchvision import models
import tqdm
from PIL import Image
from swin_transformer_pytorch import SwinTransformer
import matplotlib.pyplot as plt

import warnings
warnings.filterwarnings("ignore")

#!pip install git+https://github.com/facebookresearch/fvcore.git
#!pip install swin-transformer-pytorch
#!pip install ipywidgets

Dataloaders

In [3]:
#############################################
dataset_folder = "datasets/gtacar"
batch_size = 8
num_workers = 4

transform = transforms.Compose([
    transforms.ConvertImageDtype(torch.float),
    transforms.CenterCrop((638,1914)),
    transforms.Resize((256,768)),
    # transforms.Resize((224,672)),
    # transforms.Resize((281,512)),
    # transforms.Pad((0,115,0,116))
])
#############################################

train_images = "images/train"
train_labels = "labels/train"
val_images = "images/val"
val_labels = "labels/val"

class ImageDataset(Dataset):
    def __init__(self, dataset_folder, image_path, label_path, transform=None, nc=3):
        super().__init__()
        self.dataset_folder = dataset_folder
        self.image_filenames = glob.glob(os.path.join(dataset_folder, image_path, "*"))
        self.label_path = label_path
        self.transform = transform
        self.nc = nc

    def __len__(self):
        return len(self.image_filenames)

    def __getitem__(self, idx):
        X = transforms.functional.pil_to_tensor(Image.open(self.image_filenames[idx]))
        label_filename = self.image_filenames[idx].split("/")[-1].split(".")[0] + ".txt"
        label_path = os.path.join(self.dataset_folder, self.label_path, label_filename)

        with open(label_path) as f:
            line_data = f.readlines()[0].split(" ")
            label = torch.tensor(int(line_data[0]),dtype=torch.int64)
            bbox = [float(x) for x in line_data[1:]]
        # y = torch.zeros((3,))
        # y[label] = 1

        mask = torch.zeros(X.shape[1:])
        h,w = X.shape[1:]
        mask[int(h*(bbox[1]-bbox[3])):int(h*(bbox[1]+bbox[3])),int(w*(bbox[0]-bbox[2])):int(w*(bbox[0]+bbox[2]))] = 1
        mask = mask.unsqueeze(0)

        if not self.transform is None:
            X = self.transform(X)
            mask = self.transform(mask)
        
        mask = mask.squeeze(0)
        mask = mask.numpy()
        cols, rows = np.nonzero(mask)
        if len(cols)==0: 
            bbox = torch.zeros(4, dtype=torch.float32)
        else:
            top_row = np.min(rows)
            left_col = np.min(cols)
            bottom_row = np.max(rows)
            right_col = np.max(cols)
            bbox = torch.tensor([left_col, top_row, right_col, bottom_row], dtype=torch.float32)

        # im = X.squeeze().permute(1,2,0).numpy()
        # plt.imshow(im)
        # plt.show()
        # plt.imshow(mask)
        # plt.show()

        return X, label, bbox
    
    def get_class_weights(self):
        weights = torch.zeros((self.nc,))
        for filename in self.image_filenames:
            label_filename = filename.split("/")[-1].split(".")[0] + ".txt"
            label_path = os.path.join(self.dataset_folder, self.label_path, label_filename)
            
            with open(label_path) as f:
                y = torch.tensor(int(f.readlines()[0].split(" ")[0])-1).to(torch.int64)

            weights[y] += 1
        
        nmax = torch.max(weights)
        weights = nmax / weights
        return weights

tr_dataset = ImageDataset(dataset_folder, train_images, train_labels, transform)
weights = tr_dataset.get_class_weights().to(device)
print(len(tr_dataset))
tr_dataset = DataLoader(tr_dataset, batch_size, shuffle=True, num_workers=num_workers, pin_memory=True)
val_dataset = ImageDataset(dataset_folder, val_images, val_labels, transform)
print(len(val_dataset))
val_dataset = DataLoader(val_dataset, batch_size, shuffle=False, num_workers=num_workers, pin_memory=True)



6058
1515


Model

In [4]:
############################
lr = 1e-3
############################

#model = torch.hub.load('zhanghang1989/ResNeSt', 'resnest269', pretrained=False, dilated=False, num_classes=3).to(device)
# model = SwinTransformer(
#     hidden_dim=96,
#     layers=(2, 2, 6, 2),
#     heads=(3, 6, 12, 24),
#     channels=3,
#     num_classes=3,
#     head_dim=32,
#     window_size=8,
#     downscaling_factors=(4, 2, 2, 2),
#     relative_pos_embedding=True
# ).to(device)
# model = SwinTransformer(
#     hidden_dim=96,
#     layers=(2, 2, 6, 2),
#     heads=(3, 6, 12, 24),
#     channels=3,
#     num_classes=3,
#     head_dim=32,
#     window_size=7,
#     downscaling_factors=(4, 2, 2, 2),
#     relative_pos_embedding=True
# ).to(device)

class BBReg_SwinTransformer(nn.Module):
    def __init__(self):
        super(BBReg_SwinTransformer, self).__init__()
        swin_transformer = SwinTransformer(
            hidden_dim=96,
            layers=(2, 2, 6, 2),
            heads=(3, 6, 12, 24),
            channels=3,
            num_classes=3,
            head_dim=32,
            window_size=8,
            downscaling_factors=(4, 2, 2, 2),
            relative_pos_embedding=True
        )
        self.swin_transformer = nn.Sequential(*list(swin_transformer.children())[:-1])
        
        # Prediction head
        self.clf = nn.Sequential(
            nn.BatchNorm1d(768),
            nn.Linear(768,128),
            nn.ReLU(),
            nn.Linear(128,3),
            nn.Softmax()
        )
        self.bboxreg = nn.Sequential(
            nn.BatchNorm1d(768),
            nn.Linear(768,4)
        )


    def forward(self, x):
        x = self.swin_transformer(x)
        x = x.mean(dim=[2, 3])
        cls = self.clf(x)
        bbox = self.bboxreg(x) 
        return cls, bbox

# model = models.detection.faster_rcnn.fasterrcnn_resnet50_fpn(pretrained=True)
# in_features = model.roi_heads.box_predictor.cls_score.in_features
# model.roi_heads.box_predictor = models.detection.faster_rcnn.FastRCNNPredictor(in_features, 2)
# model = model.to(device)
model = BBReg_SwinTransformer().to(device)
criterion = torch.nn.CrossEntropyLoss(weight=weights)
bbox_criterion = torch.nn.L1Loss()
# criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=lr)

# Test forward pass
# for X, Y, B in tr_dataset:
#     # X = list(x.to(device) for x in X)
#     # targets = [{"boxes" : b.unsqueeze(0).to(device), "labels" : y.unsqueeze(0).to(torch.int64).to(device)} if y == torch.tensor(-1) else {"boxes" : torch.zeros((0,4)).to(device), "labels" : torch.zeros((0,1)).to(torch.int64).to(device)} for b,y in zip(B,Y)]
#     # print(targets[1]["labels"].dtype)
#     # print(targets)
#     X = X.to(device)
#     outputs = model(X)
#     # outputs = model(X,targets)
#     print(outputs)
#     break

Train

In [5]:
######################################
num_epochs = 50
######################################

def train_epoch(data_loader, model, criterion, optimizer, device):
    """
    Train the `model` for one epoch of data from `data_loader`
    Use `optimizer` to optimize the specified `criterion`
    """
    
    model = model.train()

    correct, total = 0, 0
    running_loss = []

    for X, Y, B in tqdm.tqdm(data_loader):
        X = X.to(device)
        Y = Y.to(device)
        B = B.to(device)
        # X = list(x.to(device) for x in X)
        # targets = [{"boxes" : b.unsqueeze(0).to(device), "labels" : y.unsqueeze(0).to(device)} if not torch.all(b == torch.zeros((1,4))) else {"boxes" : torch.zeros((0,4)).to(device), "labels" : torch.zeros((0,1)).to(device)} for b,y in zip(B,Y)]
        # targets = [{"boxes" : b.unsqueeze(0).to(device), "labels" : y.unsqueeze(0).to(torch.int64).to(device)} if y == torch.tensor(-1) else {"boxes" : torch.zeros((0,4)).to(device), "labels" : torch.zeros((0,1)).to(torch.int64).to(device)} for b,y in zip(B,Y)]
        # clear parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        output, bbox = model(X)
        loss = criterion(output, Y) + bbox_criterion(bbox,B) / 1000
        loss.backward()
        optimizer.step()

        # loss_dict = model(X, targets)
        # loss = sum(loss for loss in loss_dict.values())

        predicted = torch.argmax(output, dim=1)
        total += Y.shape[0]
        correct += (predicted == Y).sum().item()

        # calculate loss and accuracy
        running_loss.append(loss.item())

    train_loss = np.mean(running_loss)
    train_acc = correct / total

    return train_loss, train_acc

def evaluate_epoch(val_loader, model, criterion, device):
    model.eval()

    with torch.no_grad():
        Y_true, Y_pred = [], []
        correct, total = 0, 0
        running_loss = []
        for X, Y, B in tqdm.tqdm(val_loader):
            X = X.to(device)
            Y = Y.to(device)

            output, bbox = model(X)
            predicted = torch.argmax(output, dim=1)
            
            total += Y.shape[0]
            correct += (predicted == Y).sum().item()

            # loss calculation
            running_loss.append(criterion(output, Y).item())
            
        val_loss = np.mean(running_loss)
        val_acc = correct / total

    return val_loss, val_acc

def train(model, criterion, optimizer, tr_loader, val_loader, device):
    stats = {"epoch" : [], "tr_acc" : [], "tr_loss" : [], "val_acc" : [], "val_loss" : []}
    best_val_acc = 0.64
    for epoch in range(1, num_epochs+1):
        print("Epoch {}:".format(epoch))
        stats["epoch"].append(epoch)

        # Train model
        tr_loss, tr_acc = train_loss, train_acc = train_epoch(tr_loader, model, criterion, optimizer, device)

        print("Train loss = {}, train accuracy = {}".format(tr_loss, tr_acc))
        stats["tr_acc"].append(tr_acc)
        stats["tr_loss"].append(tr_loss)

        # Evaluate model
        val_loss, val_acc = evaluate_epoch(val_loader, model, criterion, device)

        print("Validation loss = {}, validation accuracy = {}".format(val_loss, val_acc))
        stats["val_acc"].append(val_acc)
        stats["val_loss"].append(val_loss)

        if val_acc > best_val_acc:
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
            }, "epoch{}.valacc{}.pth".format(epoch, val_acc))
            best_val_acc = val_acc


stats = train(model, criterion, optimizer, tr_dataset, val_dataset, device)


Epoch 1:


100%|█████████████████████████████████████████| 758/758 [03:47<00:00,  3.33it/s]


Train loss = 1.1247287253433922, train accuracy = 0.47028722350610763


100%|█████████████████████████████████████████| 190/190 [00:20<00:00,  9.11it/s]


Validation loss = 0.9777270461383619, validation accuracy = 0.29108910891089107
Epoch 2:


100%|█████████████████████████████████████████| 758/758 [03:50<00:00,  3.29it/s]


Train loss = 1.0084103127268185, train accuracy = 0.4853086827335754


100%|█████████████████████████████████████████| 190/190 [00:20<00:00,  9.09it/s]


Validation loss = 0.9006045821465944, validation accuracy = 0.4389438943894389
Epoch 3:


100%|█████████████████████████████████████████| 758/758 [03:43<00:00,  3.39it/s]


Train loss = 1.0167185017804672, train accuracy = 0.4915813799933972


100%|█████████████████████████████████████████| 190/190 [00:19<00:00,  9.86it/s]


Validation loss = 0.9513887719104165, validation accuracy = 0.35181518151815183
Epoch 4:


100%|█████████████████████████████████████████| 758/758 [03:37<00:00,  3.49it/s]


Train loss = 0.9745082603595502, train accuracy = 0.5815450643776824


100%|█████████████████████████████████████████| 190/190 [00:19<00:00,  9.84it/s]


Validation loss = 0.8573498656875209, validation accuracy = 0.66996699669967
Epoch 5:


100%|█████████████████████████████████████████| 758/758 [03:41<00:00,  3.42it/s]


Train loss = 0.9706829177672756, train accuracy = 0.6153846153846154


100%|█████████████████████████████████████████| 190/190 [00:20<00:00,  9.11it/s]


Validation loss = 0.9783403311905108, validation accuracy = 0.29504950495049503
Epoch 6:


100%|█████████████████████████████████████████| 758/758 [03:51<00:00,  3.28it/s]


Train loss = 0.9575711859090347, train accuracy = 0.6099372730274017


100%|█████████████████████████████████████████| 190/190 [00:20<00:00,  9.15it/s]


Validation loss = 0.8473713765018864, validation accuracy = 0.666006600660066
Epoch 7:


100%|█████████████████████████████████████████| 758/758 [03:50<00:00,  3.28it/s]


Train loss = 0.9741913599986829, train accuracy = 0.5845163420270716


100%|█████████████████████████████████████████| 190/190 [00:20<00:00,  9.15it/s]


Validation loss = 0.9029841608122775, validation accuracy = 0.693069306930693
Epoch 8:


100%|█████████████████████████████████████████| 758/758 [03:51<00:00,  3.27it/s]


Train loss = 0.9568357491241596, train accuracy = 0.6254539451964345


100%|█████████████████████████████████████████| 190/190 [00:20<00:00,  9.13it/s]


Validation loss = 1.1153302167591295, validation accuracy = 0.6336633663366337
Epoch 9:


100%|█████████████████████████████████████████| 758/758 [03:51<00:00,  3.28it/s]


Train loss = 0.9539416561340593, train accuracy = 0.6214922416639155


100%|█████████████████████████████████████████| 190/190 [00:20<00:00,  9.16it/s]


Validation loss = 0.8478110752607647, validation accuracy = 0.5927392739273928
Epoch 10:


100%|█████████████████████████████████████████| 758/758 [03:47<00:00,  3.33it/s]


Train loss = 0.9499886385840917, train accuracy = 0.628260151865302


100%|█████████████████████████████████████████| 190/190 [00:19<00:00,  9.80it/s]


Validation loss = 1.0780026915826295, validation accuracy = 0.6481848184818482
Epoch 11:


100%|█████████████████████████████████████████| 758/758 [03:37<00:00,  3.48it/s]


Train loss = 0.9481590966593307, train accuracy = 0.6206668867613073


100%|█████████████████████████████████████████| 190/190 [00:19<00:00,  9.82it/s]


Validation loss = 0.9810612408738387, validation accuracy = 0.29108910891089107
Epoch 12:


100%|█████████████████████████████████████████| 758/758 [03:37<00:00,  3.48it/s]


Train loss = 0.941995385611592, train accuracy = 0.6313965004952129


100%|█████████████████████████████████████████| 190/190 [00:19<00:00,  9.86it/s]


Validation loss = 0.8487045344553495, validation accuracy = 0.5821782178217821
Epoch 13:


100%|█████████████████████████████████████████| 758/758 [03:37<00:00,  3.48it/s]


Train loss = 0.9526637082521393, train accuracy = 0.61687025420931


100%|█████████████████████████████████████████| 190/190 [00:19<00:00,  9.80it/s]


Validation loss = 0.9810633706419092, validation accuracy = 0.29108910891089107
Epoch 14:


100%|█████████████████████████████████████████| 758/758 [03:37<00:00,  3.48it/s]


Train loss = 0.9471919568830556, train accuracy = 0.6198415318586993


100%|█████████████████████████████████████████| 190/190 [00:19<00:00,  9.85it/s]


Validation loss = 1.0455583716693677, validation accuracy = 0.6534653465346535
Epoch 15:


100%|█████████████████████████████████████████| 758/758 [03:37<00:00,  3.48it/s]


Train loss = 0.9489647385941961, train accuracy = 0.6190161769560911


100%|█████████████████████████████████████████| 190/190 [00:20<00:00,  9.23it/s]


Validation loss = 0.9656679197361595, validation accuracy = 0.29108910891089107
Epoch 16:


100%|█████████████████████████████████████████| 758/758 [03:52<00:00,  3.26it/s]


Train loss = 0.9852108386858787, train accuracy = 0.5407725321888412


100%|█████████████████████████████████████████| 190/190 [00:21<00:00,  9.04it/s]


Validation loss = 0.9078038705022712, validation accuracy = 0.3716171617161716
Epoch 17:


100%|█████████████████████████████████████████| 758/758 [03:51<00:00,  3.28it/s]


Train loss = 0.9561338376715819, train accuracy = 0.6107626279300099


100%|█████████████████████████████████████████| 190/190 [00:20<00:00,  9.06it/s]


Validation loss = 0.9235101938247681, validation accuracy = 0.7056105610561056
Epoch 18:


100%|█████████████████████████████████████████| 758/758 [03:50<00:00,  3.28it/s]


Train loss = 0.95983391378989, train accuracy = 0.5952459557609773


100%|█████████████████████████████████████████| 190/190 [00:20<00:00,  9.13it/s]


Validation loss = 0.9389752287613718, validation accuracy = 0.35181518151815183
Epoch 19:


100%|█████████████████████████████████████████| 758/758 [03:49<00:00,  3.31it/s]


Train loss = 0.9449299984368297, train accuracy = 0.6219874546054803


100%|█████████████████████████████████████████| 190/190 [00:19<00:00,  9.84it/s]


Validation loss = 0.9590741022637016, validation accuracy = 0.32607260726072607
Epoch 20:


100%|█████████████████████████████████████████| 758/758 [03:37<00:00,  3.48it/s]


Train loss = 0.9568773777786849, train accuracy = 0.6061406404754044


100%|█████████████████████████████████████████| 190/190 [00:19<00:00,  9.87it/s]


Validation loss = 0.9783859362727717, validation accuracy = 0.29108910891089107
Epoch 21:


100%|█████████████████████████████████████████| 758/758 [03:48<00:00,  3.31it/s]


Train loss = 0.94799134072339, train accuracy = 0.6003631561571475


100%|█████████████████████████████████████████| 190/190 [00:20<00:00,  9.06it/s]


Validation loss = 0.9442605602113824, validation accuracy = 0.2996699669966997
Epoch 22:


100%|█████████████████████████████████████████| 758/758 [03:51<00:00,  3.28it/s]


Train loss = 0.9365131299067928, train accuracy = 0.6140640475404424


100%|█████████████████████████████████████████| 190/190 [00:20<00:00,  9.11it/s]


Validation loss = 0.8457176904929312, validation accuracy = 0.7016501650165017
Epoch 23:


100%|█████████████████████████████████████████| 758/758 [03:50<00:00,  3.28it/s]


Train loss = 0.9433346746936637, train accuracy = 0.6072961373390557


100%|█████████████████████████████████████████| 190/190 [00:20<00:00,  9.09it/s]


Validation loss = 0.8963912104305468, validation accuracy = 0.4904290429042904
Epoch 24:


100%|█████████████████████████████████████████| 758/758 [03:51<00:00,  3.28it/s]


Train loss = 0.9345712888209361, train accuracy = 0.6328821393199076


100%|█████████████████████████████████████████| 190/190 [00:20<00:00,  9.09it/s]


Validation loss = 0.9081467779059159, validation accuracy = 0.4099009900990099
Epoch 25:


100%|█████████████████████████████████████████| 758/758 [03:50<00:00,  3.28it/s]


Train loss = 0.9562151895192179, train accuracy = 0.6056454275338395


100%|█████████████████████████████████████████| 190/190 [00:20<00:00,  9.12it/s]


Validation loss = 0.9619242040734542, validation accuracy = 0.6838283828382838
Epoch 26:


100%|█████████████████████████████████████████| 758/758 [03:51<00:00,  3.28it/s]


Train loss = 0.9525154628508323, train accuracy = 0.606305711455926


100%|█████████████████████████████████████████| 190/190 [00:20<00:00,  9.11it/s]


Validation loss = 0.8395660519599915, validation accuracy = 0.568976897689769
Epoch 27:


100%|█████████████████████████████████████████| 758/758 [03:51<00:00,  3.27it/s]


Train loss = 0.9391054037692993, train accuracy = 0.6117530538131396


100%|█████████████████████████████████████████| 190/190 [00:20<00:00,  9.07it/s]


Validation loss = 0.901531109684392, validation accuracy = 0.7148514851485148
Epoch 28:


100%|█████████████████████████████████████████| 758/758 [03:44<00:00,  3.37it/s]


Train loss = 0.9318145222588391, train accuracy = 0.629580719709475


100%|█████████████████████████████████████████| 190/190 [00:19<00:00,  9.81it/s]


Validation loss = 0.8188137644215634, validation accuracy = 0.6237623762376238
Epoch 29:


100%|█████████████████████████████████████████| 758/758 [03:37<00:00,  3.48it/s]


Train loss = 0.9379743460300415, train accuracy = 0.6148894024430505


100%|█████████████████████████████████████████| 190/190 [00:19<00:00,  9.85it/s]


Validation loss = 0.8403481844224427, validation accuracy = 0.5590759075907591
Epoch 30:


100%|█████████████████████████████████████████| 758/758 [03:40<00:00,  3.44it/s]


Train loss = 0.9326812402396844, train accuracy = 0.6257840871574777


100%|█████████████████████████████████████████| 190/190 [00:20<00:00,  9.12it/s]


Validation loss = 0.9297901015532645, validation accuracy = 0.3122112211221122
Epoch 31:


100%|█████████████████████████████████████████| 758/758 [03:51<00:00,  3.28it/s]


Train loss = 0.9333154570930552, train accuracy = 0.6147243314625289


100%|█████████████████████████████████████████| 190/190 [00:20<00:00,  9.09it/s]


Validation loss = 0.9367628853572042, validation accuracy = 0.3405940594059406
Epoch 32:


100%|█████████████████████████████████████████| 758/758 [03:51<00:00,  3.27it/s]


Train loss = 0.9284016676345412, train accuracy = 0.6206668867613073


100%|█████████████████████████████████████████| 190/190 [00:20<00:00,  9.07it/s]


Validation loss = 0.8305934121734218, validation accuracy = 0.5900990099009901
Epoch 33:


100%|█████████████████████████████████████████| 758/758 [03:50<00:00,  3.28it/s]


Train loss = 0.9311745529281747, train accuracy = 0.624793661274348


100%|█████████████████████████████████████████| 190/190 [00:20<00:00,  9.06it/s]


Validation loss = 1.145838709881431, validation accuracy = 0.6343234323432343
Epoch 34:


100%|█████████████████████████████████████████| 758/758 [03:44<00:00,  3.38it/s]


Train loss = 0.9377710652540101, train accuracy = 0.6021789369428855


100%|█████████████████████████████████████████| 190/190 [00:19<00:00,  9.79it/s]


Validation loss = 0.8542506108158513, validation accuracy = 0.563036303630363
Epoch 35:


100%|█████████████████████████████████████████| 758/758 [03:51<00:00,  3.28it/s]


Train loss = 0.9362101396816073, train accuracy = 0.6355232750082536


100%|█████████████████████████████████████████| 190/190 [00:21<00:00,  9.04it/s]


Validation loss = 1.0553302357071324, validation accuracy = 0.6495049504950495
Epoch 36:


100%|█████████████████████████████████████████| 758/758 [03:51<00:00,  3.27it/s]


Train loss = 0.9287036428508154, train accuracy = 0.6284252228458237


100%|█████████████████████████████████████████| 190/190 [00:21<00:00,  8.78it/s]


Validation loss = 1.056296315946077, validation accuracy = 0.662046204620462
Epoch 37:


100%|█████████████████████████████████████████| 758/758 [03:51<00:00,  3.28it/s]


Train loss = 0.9319890711584318, train accuracy = 0.6114229118520964


100%|█████████████████████████████████████████| 190/190 [00:20<00:00,  9.08it/s]


Validation loss = 0.8165407412930539, validation accuracy = 0.6633663366336634
Epoch 38:


100%|█████████████████████████████████████████| 758/758 [03:50<00:00,  3.28it/s]


Train loss = 0.9294706737145899, train accuracy = 0.6044899306701882


100%|█████████████████████████████████████████| 190/190 [00:20<00:00,  9.12it/s]


Validation loss = 1.1151555980506695, validation accuracy = 0.6402640264026402
Epoch 39:


100%|█████████████████████████████████████████| 758/758 [03:51<00:00,  3.28it/s]


Train loss = 0.9322183073196059, train accuracy = 0.6109276989105316


100%|█████████████████████████████████████████| 190/190 [00:20<00:00,  9.06it/s]


Validation loss = 0.8371621941265307, validation accuracy = 0.5927392739273928
Epoch 40:


100%|█████████████████████████████████████████| 758/758 [03:47<00:00,  3.33it/s]


Train loss = 0.9310273452610328, train accuracy = 0.6239683063717398


100%|█████████████████████████████████████████| 190/190 [00:19<00:00,  9.82it/s]


Validation loss = 1.1458313603150216, validation accuracy = 0.6343234323432343
Epoch 41:


100%|█████████████████████████████████████████| 758/758 [03:37<00:00,  3.48it/s]


Train loss = 0.9601359677975285, train accuracy = 0.5609111918124794


100%|█████████████████████████████████████████| 190/190 [00:19<00:00,  9.85it/s]


Validation loss = 0.8592874837549109, validation accuracy = 0.48184818481848185
Epoch 42:


100%|█████████████████████████████████████████| 758/758 [03:37<00:00,  3.48it/s]


Train loss = 0.9507441125161415, train accuracy = 0.5863321228128096


100%|█████████████████████████████████████████| 190/190 [00:19<00:00,  9.82it/s]


Validation loss = 0.8427387585765437, validation accuracy = 0.6937293729372938
Epoch 43:


100%|█████████████████████████████████████████| 758/758 [03:37<00:00,  3.48it/s]


Train loss = 0.9417567608538907, train accuracy = 0.6084516342027072


100%|█████████████████████████████████████████| 190/190 [00:19<00:00,  9.81it/s]


Validation loss = 0.8327942848205566, validation accuracy = 0.671947194719472
Epoch 44:


100%|█████████████████████████████████████████| 758/758 [03:37<00:00,  3.48it/s]


Train loss = 0.9505069889619672, train accuracy = 0.5998679432155827


100%|█████████████████████████████████████████| 190/190 [00:19<00:00,  9.87it/s]


Validation loss = 0.9218157049856688, validation accuracy = 0.3458745874587459
Epoch 45:


100%|█████████████████████████████████████████| 758/758 [03:46<00:00,  3.35it/s]


Train loss = 0.9471250958839004, train accuracy = 0.6018487949818422


100%|█████████████████████████████████████████| 190/190 [00:20<00:00,  9.09it/s]


Validation loss = 1.1288501711268173, validation accuracy = 0.636963696369637
Epoch 46:


100%|█████████████████████████████████████████| 758/758 [03:50<00:00,  3.29it/s]


Train loss = 0.9368256518425602, train accuracy = 0.632221855397821


100%|█████████████████████████████████████████| 190/190 [00:20<00:00,  9.07it/s]


Validation loss = 0.9810634766754351, validation accuracy = 0.29108910891089107
Epoch 47:


100%|█████████████████████████████████████████| 758/758 [03:51<00:00,  3.27it/s]


Train loss = 0.937200505768089, train accuracy = 0.6464179597226808


100%|█████████████████████████████████████████| 190/190 [00:20<00:00,  9.14it/s]


Validation loss = 0.9047683424071262, validation accuracy = 0.7095709570957096
Epoch 48:


100%|█████████████████████████████████████████| 758/758 [03:50<00:00,  3.28it/s]


Train loss = 0.9323500695203414, train accuracy = 0.6342027071640806


100%|█████████████████████████████████████████| 190/190 [00:20<00:00,  9.14it/s]


Validation loss = 0.9809532956073158, validation accuracy = 0.29108910891089107
Epoch 49:


100%|█████████████████████████████████████████| 758/758 [03:38<00:00,  3.47it/s]


Train loss = 0.9315536108683785, train accuracy = 0.6216573126444371


100%|█████████████████████████████████████████| 190/190 [00:19<00:00,  9.87it/s]


Validation loss = 0.9772003296174501, validation accuracy = 0.2957095709570957
Epoch 50:


100%|█████████████████████████████████████████| 758/758 [03:37<00:00,  3.49it/s]


Train loss = 0.9273267193331253, train accuracy = 0.6307362165731264


100%|█████████████████████████████████████████| 190/190 [00:19<00:00,  9.82it/s]

Validation loss = 1.019099130442268, validation accuracy = 0.6739273927392739





Inference

In [5]:
def inference(model, folder_path, output_path, transform, device):
    with open(output_path, "w") as f:
        f.write("guid/image,label\n")

        model.eval()
        image_paths = glob.glob(os.path.join(folder_path, "*"))
        for image_path in tqdm.tqdm(image_paths):
            im = transforms.functional.pil_to_tensor(Image.open(image_path))
            im = transform(im).to(device)
            im = im.unsqueeze(0)
            pred, bbox = model(im)
            prediction = torch.argmax(pred).item()

            guid = image_path.split("/")[-1].split(".")[0]
            im_no = image_path.split("/")[-1].split(".")[1]

            f.write(guid + "/" + im_no + "," + str(prediction) + "\n")

checkpoint = torch.load("epoch27.valacc0.7148514851485148.pth")
model.load_state_dict(checkpoint['model_state_dict'])   
inference(model, os.path.join(dataset_folder,"images/test"), "output.txt", transform, device)

100%|███████████████████████████████████████| 2631/2631 [01:31<00:00, 28.83it/s]


In [6]:
def thresh_evaluate_epoch(val_loader, model, criterion, thresh, device):
    model.eval()

    with torch.no_grad():
        Y_true, Y_pred = [], []
        correct, total = 0, 0
        running_loss = []
        for X, Y, B in tqdm.tqdm(val_loader):
            X = X.to(device)
            Y = Y.to(device)
            print(Y)

            output, bbox = model(X)
            predicted = torch.argmax(output, dim=1)
            print(output)
            maximums = torch.max(output,dim=1)
            break
            predicted[torch.max(output,dim=1) < torch.ones_like(predicted)*thresh] = 0
            
            total += Y.shape[0]
            correct += (predicted == Y).sum().item()

            # loss calculation
            running_loss.append(criterion(output, Y).item())
            
        val_loss = np.mean(running_loss)
        val_acc = correct / total

    return val_loss, val_acc

print(thresh_evaluate_epoch(val_dataset, model, criterion, 0.4, device))

  0%|                                                   | 0/190 [00:00<?, ?it/s]

tensor([1, 1, 1, 1, 1, 1, 2, 1], device='cuda:0')
tensor([[3.0121e-08, 3.4171e-01, 6.5829e-01],
        [2.9377e-09, 2.4324e-01, 7.5676e-01],
        [1.8586e-06, 6.9148e-01, 3.0852e-01],
        [2.0720e-31, 4.8991e-06, 1.0000e+00],
        [0.0000e+00, 1.0000e+00, 7.2527e-09],
        [5.8015e-40, 9.9982e-01, 1.8027e-04],
        [0.0000e+00, 1.0000e+00, 1.2087e-07],
        [2.9317e-27, 9.9686e-01, 3.1419e-03]], device='cuda:0')


  0%|                                                   | 0/190 [00:00<?, ?it/s]


ZeroDivisionError: ignored