In [None]:
import os
from PIL import Image
import numpy as np
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
import torch
import torchvision
from glob import glob
import torch.nn as nn
from tqdm import tqdm
import torch.nn.functional as F
import torchvision.transforms as transform
from torch.utils.data import DataLoader,Dataset

import sys
sys.path.append("/Users/nathanieljames/Desktop/dl_final_proj/dl_final_proj")
from utils.cityscapes import CityscapesDataset, get_loaders

In [None]:
# Build the train and test DataLoaders through the project helper and report
# how many batches each one yields per epoch.
# NOTE(review): the meaning of subclass='5' is defined in utils.cityscapes,
# which is not visible from this notebook -- confirm before changing it.
train_loader, test_loader = get_loaders(batch_size=8, subclass='5')
print(*map(len, (train_loader, test_loader)))

In [None]:
# Default floating-point dtype and compute device for local (Apple Silicon) runs.
# torch.device("mps") can be constructed even where the MPS backend is missing;
# the error would only surface later, when a tensor or module is moved to it.
# Availability is therefore checked up front, with CPU as the fallback.
dtype = torch.float
_mps_backend = getattr(torch.backends, "mps", None)  # absent on older torch builds
device = torch.device(
    "mps" if _mps_backend is not None and _mps_backend.is_available() else "cpu"
)

In [None]:
# Pick the compute device in order of preference: CUDA, then Apple's MPS
# backend, then CPU. Falling straight through to CPU when CUDA is missing would
# silently discard an MPS device on Apple Silicon machines.
if torch.cuda.is_available():
    print('using cuda')
    device = torch.device("cuda")
elif getattr(torch.backends, "mps", None) is not None and torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

In [None]:
# Mount Google Drive when running on Colab.
# The google.colab package only exists inside Colab, so the import is guarded:
# on any other machine the cell is a no-op instead of raising ImportError and
# stopping a "Run All".
try:
    from google.colab import drive
except ImportError:
    drive = None
    print("google.colab not available; skipping Google Drive mount")
else:
    drive.mount('/content/drive')

In [None]:
class ResNetBlock(nn.Module):
    """Residual block: conv -> BN -> ReLU -> conv -> BN, add the skip path, ReLU.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels produced by both convolutions.
        kernel_size: Kernel size shared by both convolutions.
        stride: Stride of the first convolution. The second convolution always
            uses stride 1 so the block downsamples at most once.
        padding: Padding shared by both convolutions.

    When ``in_channels == out_channels`` and the stride is 1, the skip path is
    the identity and the block has exactly the parameters ``conv1``, ``conv2``,
    ``bn1`` and ``bn2``. Otherwise a 1x1 strided convolution followed by
    BatchNorm projects the input so it can be added to the main path.

    The skip and main paths only line up spatially when ``kernel_size`` and
    ``padding`` preserve the spatial size at stride 1 (e.g. 3 and 1); other
    combinations still fail at the addition.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding):
        super(ResNetBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
        # Only conv1 strides; striding twice would shrink the main path more
        # than any single skip connection could match.
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size, 1, padding)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()

        # conv1.stride is always a tuple, so this also covers tuple strides.
        downsamples = any(s != 1 for s in self.conv1.stride)
        if in_channels != out_channels or downsamples:
            # Projection shortcut: match channel count and spatial size.
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride),
                nn.BatchNorm2d(out_channels),
            )
        else:
            # Identity shortcut: no extra module, so parameter names are unchanged.
            self.shortcut = None

    def forward(self, x):
        """Apply the block to ``x`` and return the activated sum of both paths."""
        residual = x if self.shortcut is None else self.shortcut(x)
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        # In-place add avoids allocating another full-size activation tensor.
        out += residual
        out = self.relu(out)
        return out

class ResNet(nn.Module):
    """Fully convolutional image-to-image network made of stacked residual blocks.

    Layout: a 3x3 stem convolution with BatchNorm and ReLU, ``num_blocks``
    ``ResNetBlock`` modules of constant width, and a final 3x3 convolution that
    maps the features to ``out_channels``. Every convolution uses kernel 3,
    stride 1 and padding 1, so the output has the same height and width as the
    input.

    Args:
        in_channels: Channels of the input image (default 3).
        out_channels: Channels of the output tensor (default 3).
        width: Number of feature channels inside the network (default 64).
        num_blocks: Number of residual blocks (default 18).

    The blocks are registered as ``res1`` ... ``res{num_blocks}``, so with the
    default arguments the ``state_dict`` keys are ``conv1``, ``bn1``,
    ``res1``..``res18`` and ``conv2``.
    """

    def __init__(self, in_channels=3, out_channels=3, width=64, num_blocks=18):
        super(ResNet, self).__init__()
        self.num_blocks = num_blocks
        self.conv1 = nn.Conv2d(in_channels, width, 3, 1, 1)
        self.bn1 = nn.BatchNorm2d(width)
        self.relu = nn.ReLU()
        # Register the blocks under numbered attribute names (res1, res2, ...)
        # rather than a ModuleList so that checkpoint keys keep this naming.
        for idx in range(1, num_blocks + 1):
            setattr(self, f"res{idx}", ResNetBlock(width, width, 3, 1, 1))
        # Project the features back to the requested number of output channels.
        self.conv2 = nn.Conv2d(width, out_channels, 3, 1, 1)

    def forward(self, x):
        """Run the stem, every residual block in order, then the output conv."""
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        for idx in range(1, self.num_blocks + 1):
            out = getattr(self, f"res{idx}")(out)
        out = self.conv2(out)
        return out

In [None]:
# Build the network in float32.
model = ResNet().float()

# Print a layer-by-layer summary while the model is still on the CPU.
# torchsummary defaults to device="cuda" and then builds its dummy input on the
# GPU whenever CUDA is available, which does not match CPU-resident weights.
# Pinning the summary to the CPU keeps input and weights together on every
# machine; the model is moved to the training device afterwards.
from torchsummary import summary
summary(model, (3,256,256), device="cpu")
model = model.to(device)

# Number of passes over the training set.
epochs = 25

# Mean-squared-error objective, optimised with Adam.
loss_func = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

# Per-epoch history; the loss lists are appended to by the training loop.
train_acc = []
val_acc = []
train_loss = []
val_loss = []

In [None]:
# Train for `epochs` epochs. After each epoch the model is evaluated on
# test_loader, and the weights are saved whenever the mean validation loss
# improves on the best value seen so far.
best_loss = float('inf')
for i in range(epochs):

    trainloss = 0
    valloss = 0

    # Training mode: BatchNorm normalises with batch statistics and updates
    # its running estimates.
    model.train()
    for img,label in tqdm(train_loader):
        optimizer.zero_grad()
        img = img.to(device)
        label = label.to(device)
        output = model(img)
        loss = loss_func(output,label)
        loss.backward()
        optimizer.step()
        trainloss+=loss.item()

    train_loss.append(trainloss/len(train_loader))

    # Evaluation mode: BatchNorm uses its running statistics and leaves them
    # untouched, so validation batches neither perturb the model nor depend on
    # their own batch composition. no_grad skips building autograd graphs that
    # are never back-propagated.
    model.eval()
    with torch.no_grad():
        for img,label in tqdm(test_loader):
            img = img.to(device)
            label = label.to(device)
            output = model(img)
            loss = loss_func(output,label)
            valloss+=loss.item()

    epoch_loss = valloss/len(test_loader)
    val_loss.append(epoch_loss)
    print("epoch : {} ,train loss : {} ,valid loss : {} ".format(i,train_loss[-1],val_loss[-1]))
    if epoch_loss < best_loss:
        best_loss = epoch_loss
        print(f"new best loss: {best_loss}")
        print("saving")
        torch.save(model.state_dict(), '/content/drive/MyDrive/direct_resnet_1.pth')