In [None]:
# Make the Google Drive dataset folders visible under /content/drive (Colab only).
from google.colab import drive

drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# Base Google Drive folder; the next cell re-points `dataroot` at the dataset root.
dataroot = '/content/drive/MyDrive/'
# One-time data extraction commands, kept commented out for reference:
#   !unzip /content/drive/MyDrive/images_zip.zip -d "/content/drive/My Drive/585_data"
#   !unzip /content/drive/MyDrive/ground_truth_zip.zip -d "/content/drive/My Drive/585_data"

# NOTE(review): IS_GPU is not referenced anywhere else in the visible notebook.
IS_GPU = True

# Mini-batch sizes: TEST_BS for the validation loader, TRAIN_BS for training.
TEST_BS, TRAIN_BS = 2, 4

In [None]:

from __future__ import print_function
from PIL import Image
import os
import os.path
import numpy as np
import sys
if sys.version_info[0] == 2:
    import cPickle as pickle
else:
    import pickle

from torchvision.datasets.utils import download_url, check_integrity

import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
import torch.utils.data
import torchvision
import torchvision.transforms as transforms
import torchvision.datasets
import torch.nn as nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import StepLR
import glob
import cv2
import torch
import torchvision.utils as vutils
import torchvision.models as models
import torch.optim as optim
import scipy.io
from torch.utils.data import Dataset

# Let duplicate OpenMP runtimes coexist instead of aborting at import/run time.
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

# Dataset root containing images/<split>/ and ground_truth/<split>/ (read by GANDataset).
dataroot = '/content/drive/MyDrive/585_data/alt_images/'

# Deterministic pipeline: 224x224 centre crop, then PIL -> float tensor.
test_transform = transforms.Compose([
    transforms.CenterCrop((224, 224)),
    transforms.ToTensor(),
])

# Training pipeline: same crop and tensor conversion, plus a random horizontal flip
# and mild colour jitter.
# NOTE(review): GANDataset also applies this pipeline to the ground truth unless a
# separate target pipeline is supplied, so ColorJitter perturbs the targets too.
train_transform = transforms.Compose([
    transforms.CenterCrop((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),
    transforms.ToTensor(),
])


class GANDataset(Dataset):
    """Paired (image, ground-truth) dataset read from two parallel folders.

    Files are read from ``<root_dir>/images/<split>/`` and
    ``<root_dir>/ground_truth/<split>/``. They are paired purely by position after
    sorting each directory listing, so both folders must hold matching,
    identically-ordered file names.

    Args:
        root_dir: dataset root directory.
        split: sub-folder name, e.g. ``'train'``, ``'val'`` or ``'test'``.
        transform: optional callable applied to the image. If ``target_transform``
            is None it is also applied to the ground truth (original behaviour).
        target_transform: optional separate callable for the ground truth, e.g. a
            pipeline with the same geometric ops but no ColorJitter. Its random
            geometric ops must be drawn in the same order as in ``transform`` to
            stay in sync.
    """

    def __init__(self, root_dir, split, transform=None, target_transform=None):
        self.root_dir = root_dir
        self.split = split
        self.transform = transform
        self.target_transform = target_transform

        self.img_folder = os.path.join(self.root_dir, "images/", self.split)
        self.gt_folder = os.path.join(self.root_dir, "ground_truth/", self.split)

        self.img_filenames = sorted(os.listdir(self.img_folder))
        self.gt_filenames = sorted(os.listdir(self.gt_folder))
        if len(self.img_filenames) != len(self.gt_filenames):
            # Positional pairing silently misaligns (or later raises IndexError)
            # when the folders differ, so surface it at construction time.
            import warnings
            warnings.warn(
                "GANDataset: {} images but {} ground-truth files in split {!r}; "
                "pairs are matched by sorted position and may be misaligned.".format(
                    len(self.img_filenames), len(self.gt_filenames), self.split))

    def __len__(self):
        return len(self.img_filenames)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_folder, self.img_filenames[idx])
        gt_path = os.path.join(self.gt_folder, self.gt_filenames[idx])

        # Decode eagerly inside `with` so the file handles are closed. Converting
        # the image to RGB lets grayscale/RGBA files still yield 3 channels.
        with Image.open(img_path) as src:
            img = src.convert('RGB')
        with Image.open(gt_path) as src:
            gt = src.convert('RGB')

        gt_transform = self.target_transform if self.target_transform is not None else self.transform

        if self.transform:
            # Replay the torch RNG for the target so random ops such as
            # RandomHorizontalFlip make the same decision for image and target.
            rng_state = torch.get_rng_state()
            img = self.transform(img)
            torch.set_rng_state(rng_state)
        if gt_transform:
            gt = gt_transform(gt)

        return img, gt

# Training data uses the random augmentation pipeline.
dataset = GANDataset(dataroot, split='train', transform=train_transform)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

dataloader = torch.utils.data.DataLoader(dataset, batch_size=TRAIN_BS,
                                         shuffle=True, num_workers=0)

# Validation uses the deterministic pipeline, so repeated validation passes score
# the same inputs and the "best" checkpoint is not chosen on augmentation noise.
# Shuffling is unnecessary for computing a mean loss.
val_dataset = GANDataset(dataroot, split='val', transform=test_transform)

val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=TEST_BS,
                                             shuffle=False, num_workers=0)



In [None]:
#from https://github.com/captanlevi/Contour-Detection-Pytorch
class Encoder(nn.Module):
    """VGG16-based encoder that records max-pool indices for later unpooling.

    It runs five convolutional stages taken from the backbone's first child
    module (``features`` for torchvision's VGG). The backbone's own pooling layers
    are replaced by index-returning ``MaxPool2d`` layers so a decoder can invert
    them with ``max_unpool2d``. A 3x3 conv (512 -> 4096 channels) follows the
    fifth pool.

    ``forward`` returns ``(x, pooling_info, layer_info)``. For each stage k in 1..5:

    * ``pooling_info[k]`` holds keyword arguments for ``max_unpool2d``
      (kernel_size, stride, padding, output_size, indices);
    * ``layer_info[k]["value"]`` is that stage's pre-pool activation.
    """

    # Half-open index ranges into the backbone Sequential for each conv stage.
    # The skipped indices (4, 9, 16, 23, 30) are, in torchvision's VGG16 layout,
    # the backbone's MaxPool2d layers; pool1..pool5 below are used instead.
    _STAGES = ((0, 4), (5, 9), (10, 16), (17, 23), (24, 30))

    def __init__(self, vgg):
        super().__init__()
        # First child of the backbone; registered as a submodule named `vgg`.
        self.vgg = list(vgg.children())[0]
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True)
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True)
        self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True)
        self.pool5 = nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True)
        self.conv6 = nn.Conv2d(in_channels=512, out_channels=4096, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        pooling_info = {}
        layer_info = {}
        pools = (self.pool1, self.pool2, self.pool3, self.pool4, self.pool5)

        for stage, ((first, stop), pool) in enumerate(zip(self._STAGES, pools), start=1):
            for layer_idx in range(first, stop):
                x = self.vgg[layer_idx](x)

            pre_pool_shape = x.shape
            layer_info[stage] = {"value": x}
            x, indices = pool(x)
            pooling_info[stage] = {
                "kernel_size": 2,
                "stride": 2,
                "padding": 0,
                "output_size": pre_pool_shape,
                "indices": indices,
            }

        x = self.conv6(x)
        return x, pooling_info, layer_info
         

class Decoder(nn.Module):
    """Decoder that mirrors ``Encoder`` using stored max-pool indices.

    It takes the encoder's output tuple ``(x, pooling_info, ...)``:

    * a 1x1 conv maps 4096 -> 512 channels;
    * transposed 5x5 convs (each followed by ReLU) step the channels down
      512 -> 512 -> 256 -> 128 -> 64 -> 32, with ``max_unpool2d`` driven by
      ``pooling_info[5]`` .. ``pooling_info[1]`` undoing each encoder pool;
    * a final transposed conv gives 3 channels, squashed to (0, 1) by a sigmoid.
    """

    def __init__(self):
        super().__init__()
        self.dconv6 = nn.Conv2d(in_channels=4096, out_channels=512, kernel_size=1, stride=1)

        self.deconv5 = nn.ConvTranspose2d(in_channels=512, out_channels=512, kernel_size=5, padding=2)
        self.deconv4 = nn.ConvTranspose2d(in_channels=512, out_channels=256, kernel_size=5, padding=2)
        self.deconv3 = nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=5, padding=2)
        self.deconv2 = nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=5, padding=2)
        self.deconv1 = nn.ConvTranspose2d(in_channels=64, out_channels=32, kernel_size=5, padding=2)
        self.pred = nn.ConvTranspose2d(in_channels=32, out_channels=3, kernel_size=5, padding=2)

    def forward(self, encoder_out):
        features, pool_meta = encoder_out[0], encoder_out[1]

        # Undo the encoder's deepest (5th) pool first.
        out = F.max_unpool2d(F.relu(self.dconv6(features)), **pool_meta[5])

        # Each deconv is followed by unpooling with the next-shallower pool's indices.
        for deconv, stage in ((self.deconv5, 4), (self.deconv4, 3),
                              (self.deconv3, 2), (self.deconv2, 1)):
            out = F.max_unpool2d(F.relu(deconv(out)), **pool_meta[stage])

        out = F.relu(self.deconv1(out))
        return torch.sigmoid(self.pred(out))
    
class countour_detector(nn.Module):
    """Encoder-decoder contour network built on a VGG-style backbone.

    Both halves are moved to the global ``device``. The backbone's first child
    module is held by reference, not copied, so several detectors built from the
    same backbone object share those layers.
    """

    def __init__(self, backbone):
        super().__init__()
        backbone_on_device = backbone.to(device)
        self.encoder = Encoder(backbone_on_device).to(device)
        self.decoder = Decoder().to(device)

    def forward(self, x):
        # The decoder consumes the encoder's full (features, pooling_info, layer_info) tuple.
        encoded = self.encoder(x)
        return self.decoder(encoded)


# ImageNet-pretrained VGG16 (weights are downloaded on first use), moved to `device`.
vgg16 = torchvision.models.vgg16(pretrained=True).to(device)

# Freeze every backbone weight; the optimizer later keeps only parameters that
# still require gradients.
for param in vgg16.parameters():
    param.requires_grad_(False)


Downloading: "https://download.pytorch.org/models/vgg16-397923af.pth" to /root/.cache/torch/hub/checkpoints/vgg16-397923af.pth
100%|██████████| 528M/528M [00:01<00:00, 310MB/s]


In [None]:
#From PS2 template, and from https://github.com/captanlevi/Contour-Detection-Pytorch
generator = countour_detector(backbone=vgg16).to(device)

# Adam over the trainable parameters only; the frozen VGG weights are excluded.
optimizer = torch.optim.Adam(
    [param for param in generator.parameters() if param.requires_grad],
    lr=1e-4,
    weight_decay=1e-4,
)
def context_loss(outputs, targets, edge_threshold=0.98, edge_weight=10.0):
    """Weighted binary cross-entropy that up-weights bright target pixels.

    Target entries ``>= edge_threshold`` get weight ``edge_weight``; all others
    get weight 1. Presumably the bright pixels are the (sparse) contour pixels
    of the ground-truth map — TODO confirm against the data.

    Args:
        outputs: predicted probabilities in [0, 1], same shape as ``targets``.
        targets: ground-truth values in [0, 1].
        edge_threshold: target value at or above which the higher weight applies.
        edge_weight: weight for those entries.

    Returns:
        Scalar tensor: the weighted mean BCE.
    """
    # ones_like already matches targets' device and dtype. Unlike empty_like, it
    # leaves no uninitialised weights (e.g. for NaN targets, which fail both tests).
    weights = torch.ones_like(targets)
    weights[targets >= edge_threshold] = edge_weight
    return F.binary_cross_entropy(outputs, targets, weight=weights)

def _mean_val_loss(model, loader):
    """Return the mean `context_loss` of `model` over `loader`, or None if it yields no batches.

    The model is switched to eval mode for the pass and back to train mode afterwards.
    """
    model.eval()
    total, n_batches = 0.0, 0
    with torch.no_grad():
        for data, target in loader:
            total += context_loss(model(data.to(device)), target.to(device)).item()
            n_batches += 1
    model.train()
    return total / n_batches if n_batches else None


best_loss = float('inf')
num_epochs = 15
best_model_state_dict = None
for epoch in range(num_epochs):
    generator.train()
    running_loss = 0.0
    for batch_idx, (real_samples, labels) in enumerate(dataloader):
        real_samples = real_samples.to(device)
        labels = labels.to(device)

        # Train the generator
        generator.zero_grad()
        output = generator(real_samples)
        g_loss = context_loss(output, labels)
        g_loss.backward()
        optimizer.step()
        running_loss += g_loss.item()

        if batch_idx % 10 == 0:
            print("Epoch [{}/{}], Batch [{}/{}],  G_loss: {:.4f}".format(
                epoch+1, num_epochs, batch_idx+1, len(dataloader), g_loss.item()))

        if batch_idx % 200 == 0:
            # Validation lives in a helper so its loop variables cannot shadow `batch_idx`.
            val_loss = _mean_val_loss(generator, val_dataloader)
            print(val_loss)
            if val_loss is not None and val_loss < best_loss:
                best_loss = val_loss
                # state_dict() returns references to the live parameters; snapshot
                # them on CPU so later optimizer steps don't overwrite the "best"
                # weights. This also lets the file load on CPU-only machines.
                best_model_state_dict = {
                    name: tensor.detach().cpu().clone()
                    for name, tensor in generator.state_dict().items()
                }
                torch.save(best_model_state_dict, dataroot+'best_enc_dec_model.pt')

    if len(dataloader):
        print("Epoch [{}/{}], mean G_loss: {:.4f}".format(
            epoch+1, num_epochs, running_loss / len(dataloader)))
            


KeyboardInterrupt: ignored

In [None]:
my_model = countour_detector(backbone=vgg16).to(device)
# map_location lets a checkpoint written on a GPU runtime load on a CPU-only one.
my_model.load_state_dict(torch.load(dataroot+'best_enc_dec_model.pt', map_location=device))

def loss(outputs, targets, edge_threshold=0.98, edge_weight=10.0):
    """Weighted binary cross-entropy used for test-set evaluation.

    This is the same weighting as the training loss. Target entries
    ``>= edge_threshold`` get weight ``edge_weight``; all others get weight 1.

    Args:
        outputs: predicted probabilities in [0, 1], same shape as ``targets``.
        targets: ground-truth values in [0, 1].
        edge_threshold: target value at or above which the higher weight applies.
        edge_weight: weight for those entries.

    Returns:
        Scalar tensor: the weighted mean BCE.
    """
    # ones_like matches targets' device/dtype and leaves no uninitialised weights.
    weights = torch.ones_like(targets)
    weights[targets >= edge_threshold] = edge_weight
    return F.binary_cross_entropy(outputs, targets, weight=weights)


def simple_predict(model, num_batches=2):
    """Plot input, prediction and ground truth for a few test samples.

    Shows the first sample of each of the first ``num_batches`` test batches
    (batch size 4, unshuffled). Uses the deterministic ``test_transform`` so the
    model is not shown randomly flipped or colour-jittered inputs.

    Args:
        model: the contour network; it is left in eval mode.
        num_batches: how many leading batches to visualise.
    """
    model.eval()

    dataset = GANDataset(dataroot, split='test', transform=test_transform)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=4,
                                             shuffle=False, num_workers=0)

    with torch.no_grad():
        for batch_idx, (real_samples, labels) in enumerate(dataloader):
            if batch_idx >= num_batches:
                break
            output = model(real_samples.to(device))

            # CHW -> HWC for imshow; the loader yields CPU tensors.
            panels = (
                ("Image", real_samples[0].numpy().transpose((1, 2, 0))),
                ("Model Output", output[0].cpu().numpy().transpose((1, 2, 0))),
                ("Ground Truth", labels[0].numpy().transpose((1, 2, 0))),
            )

            fig, axes = plt.subplots(1, len(panels), figsize=(15, 5))
            for ax, (title, picture) in zip(axes, panels):
                ax.axis("off")
                ax.set_title(title)
                ax.imshow(picture)
            plt.show()
            # Free the figure so repeated calls don't accumulate open figures.
            plt.close(fig)


simple_predict(my_model)

NameError: ignored

In [None]:
my_model = countour_detector(backbone=vgg16).to(device)
# map_location lets a checkpoint written on a GPU runtime load on a CPU-only one.
my_model.load_state_dict(torch.load(dataroot+'best_enc_dec_model.pt', map_location=device))

              
def avg_test_loss(model):
    """Return the mean per-image weighted BCE (`loss`) of `model` over the test split.

    Uses the deterministic ``test_transform`` so the score is reproducible and not
    measured on randomly augmented inputs. The model is left in eval mode.

    Raises:
        ValueError: if the test split contains no samples.
    """
    model.eval()
    test_loss = 0.0
    num_samples = 0

    dataset = GANDataset(dataroot, split='test', transform=test_transform)
    # batch_size=1, so each batch is one sample; order is irrelevant to the mean.
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=1,
                                             shuffle=False, num_workers=0)

    with torch.no_grad():
        for real_samples, labels in dataloader:
            output = model(real_samples.to(device))
            test_loss += loss(output, labels.to(device)).item()
            num_samples += 1

    if num_samples == 0:
        raise ValueError("avg_test_loss: the 'test' split contains no samples")
    return test_loss / num_samples


avg_test_loss(my_model)