Import packages; load images and labels

In [None]:
import numpy as np
import matplotlib as mpl
#mpl.use('Agg')
import matplotlib.pyplot as plt
import glob
import os

import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils import data as D
from sklearn.model_selection import train_test_split
from skimage.color import rgb2gray
from sklearn.preprocessing import OneHotEncoder
print(torch.__version__)
import time

def predict_image(image, net=None, dev=None):
    """Run the model on a single image (or a stack of images) and return NumPy output.

    Parameters
    ----------
    image : numpy.ndarray
        Array whose last two axes are (height, width). Any leading axes are
        folded into the batch dimension and a single channel axis is inserted,
        so the model receives a float tensor of shape (batch, 1, height, width).
    net : callable, optional
        Model to evaluate. Defaults to the notebook-global ``model``.
    dev : torch.device, optional
        Device to run on. Defaults to the notebook-global ``device``.

    Returns
    -------
    numpy.ndarray
        The model output moved to the CPU (e.g. (batch, 2, height, width) raw
        per-class outputs for the UNet in this notebook; no softmax applied).
    """
    # Only fall back to the notebook globals when no explicit model/device is given.
    net = model if net is None else net
    dev = device if dev is None else dev

    arr = np.asarray(image)
    # Use the last two axes for height/width. The previous version used
    # shape[0] for both, which only worked for square 2-D images.
    batch = arr.reshape(-1, 1, arr.shape[-2], arr.shape[-1])
    inputs = torch.from_numpy(batch).to(dev, dtype=torch.float)

    # Inference only: don't build an autograd graph (saves memory on large images).
    with torch.no_grad():
        output = net(inputs)
    return output.detach().cpu().numpy()


In [None]:
# Load the image stack and add a singleton channel axis -> (N, 1, 1024, 1024).
# Fixed: the reshape previously read `images.shape` inside the same expression
# that first defines `images`, which raises NameError on a fresh kernel.
# NOTE(review): images come from train_sig2/ while labels come from no_blur/;
# confirm both files hold the same samples in the same order.
images = np.load('/home/jay/data/train_sig2/test_images.npy')
images = images.reshape(images.shape[0], 1, 1024, 1024)
labels = np.load('/home/jay/data/no_blur/test_labels.npy')

print('images: ',images.shape)
print('labels: ',labels.shape)

# Each image must have a corresponding label entry.
assert images.shape[0] == labels.shape[0], \
    'image/label count mismatch: {} vs {}'.format(images.shape[0], labels.shape[0])

Create PyTorch tensors and a DataLoader to feed the data to the model

In [None]:
# Convert each whole array to a float32 tensor in one step (copies the data).
# The previous list-of-per-sample-tensors + torch.stack approach was much
# slower and briefly held two full float32 copies of the data in memory.
tensor_im = torch.tensor(images, dtype=torch.float32)
tensor_lab = torch.tensor(labels, dtype=torch.float32)

dataset = torch.utils.data.TensorDataset(tensor_im,tensor_lab)
# Shuffled loader intended for training; when predictions must line up with
# `images`/`labels`, iterate the dataset with shuffle=False instead.
dataloader = torch.utils.data.DataLoader(dataset,batch_size=8,shuffle=True)

Initialize CNN architectures and set up for training

In [None]:
class Encoder_Decoder(nn.Module):
    """Plain convolutional encoder-decoder producing 2-channel per-pixel outputs.

    Three conv + 2x2 max-pool stages downsample the input, three conv + 2x
    upsample stages bring it back up, and two final convolutions map to 2
    channels. Input is (batch, 1, H, W); for H and W divisible by 8 the
    output is (batch, 2, H * W), i.e. the spatial dims flattened per sample.
    """

    def __init__(self):
        # Fixed: this previously called super(Net, self), but no `Net` class exists.
        super(Encoder_Decoder, self).__init__()
        # Even kernels with padding k/2 grow the map by one pixel; the
        # following 2x2 pool floors it back to exactly half.
        self.conv1 = nn.Conv2d(1,8,20,padding=(10,10))
        self.conv2 = nn.Conv2d(8,16,10,padding=(5,5))
        self.conv3 = nn.Conv2d(16,32,3,padding=(1,1))
        self.conv4 = nn.Conv2d(32,32,3,padding=(1,1))
        self.conv5 = nn.Conv2d(32,16,3,padding=(1,1))
        self.conv6 = nn.Conv2d(16,8,3,padding=(1,1))
        self.conv7 = nn.Conv2d(8,2,3,padding=(1,1))
        self.conv8 = nn.Conv2d(2,2,1)
        self.pool = nn.MaxPool2d(2,2)
        self.upsample = nn.Upsample(scale_factor=2)

    def forward(self,x):
        # Encoder: conv -> pool -> ReLU, three times (spatial size / 8).
        x = F.relu(self.pool(self.conv1(x)))
        x = F.relu(self.pool(self.conv2(x)))
        x = F.relu(self.pool(self.conv3(x)))
        # Decoder: conv -> 2x upsample -> ReLU, three times (spatial size * 8).
        x = F.relu(self.upsample(self.conv4(x)))
        x = F.relu(self.upsample(self.conv5(x)))
        x = F.relu(self.upsample(self.conv6(x)))
        x = self.conv7(x)
        x = self.conv8(x)
        # Flatten the spatial dims per sample. The previous hard-coded reshape to
        # (-1, 2, 512*512) silently folded extra pixels into the batch dimension
        # for any input that was not 512x512.
        return torch.flatten(x, start_dim=2)

net = Encoder_Decoder()

        
class UNet(nn.Module):
    """Small three-level U-Net producing 2-channel per-pixel outputs.

    Encoder: three 3x3 conv + ReLU stages, each followed by 2x2 max-pooling.
    Decoder: three stride-2 transposed convolutions, each upsampling to the
    exact spatial size of the matching encoder activation, which is then
    concatenated along the channel axis as a skip connection.
    Input (batch, 1, H, W) -> output (batch, 2, H, W); H and W must be
    divisible by 8 so every transposed conv can hit its target size.

    NOTE(review): the final output passes through ReLU, so the scores given to
    CrossEntropyLoss are never negative, and ``conv8`` is created (and stored in
    checkpoints) but never used in ``forward``. Left as-is so existing
    checkpoints keep behaving identically.
    """

    def __init__(self):
        super(UNet, self).__init__()
        # Encoder convolutions: 1 -> 8 -> 16 -> 32 channels, size-preserving.
        self.conv1 = nn.Conv2d(1, 8, 3, padding=(1, 1))
        self.conv2 = nn.Conv2d(8, 16, 3, padding=(1, 1))
        self.conv3 = nn.Conv2d(16, 32, 3, padding=(1, 1))
        # Decoder: input channels = upsampled features + concatenated skip.
        self.conv4 = nn.ConvTranspose2d(32, 16, 3, stride=2, padding=(1, 1))
        self.conv5 = nn.ConvTranspose2d(16 + 32, 8, 3, stride=2, padding=(1, 1))
        self.conv6 = nn.ConvTranspose2d(8 + 16, 8, 3, stride=2, padding=(1, 1))
        self.conv7 = nn.Conv2d(8 + 8, 2, 3, padding=(1, 1))
        self.conv8 = nn.Conv2d(2, 2, 1)  # unused in forward; kept for checkpoint compatibility
        self.pool = nn.MaxPool2d(2, 2)

    def forward(self, x):
        # Encoder: keep each pre-pool activation for the skip connections.
        skip1 = F.relu(self.conv1(x))
        skip2 = F.relu(self.conv2(self.pool(skip1)))
        skip3 = F.relu(self.conv3(self.pool(skip2)))
        bottom = self.pool(skip3)

        # Decoder: upsample to the skip's size, ReLU, then concatenate the skip.
        up = F.relu(self.conv4(bottom, output_size=skip3.size()))
        up = F.relu(self.conv5(torch.cat((up, skip3), dim=1), output_size=skip2.size()))
        up = F.relu(self.conv6(torch.cat((up, skip2), dim=1), output_size=skip1.size()))
        return F.relu(self.conv7(torch.cat((up, skip1), dim=1)))


# The UNet is the model trained and evaluated below (the Encoder_Decoder
# instance `net` is not used by the training or inference cells).
model = UNet()

# Per-pixel classification loss; the training loop passes the 2-channel model
# output and `labels.long()` as targets.
criterion = nn.CrossEntropyLoss()

optimizer = optim.Adam(model.parameters(),lr = 0.0001) 

# Multiplies the LR by MultiStepLR's default gamma (0.1) once 10 scheduler
# steps have been taken. NOTE(review): scheduler.step() is commented out in the
# training loop, so as written the learning rate never changes.
scheduler = optim.lr_scheduler.MultiStepLR(optimizer,milestones=[10])


# First GPU if available, otherwise CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(device)
# Module.to() moves parameters in place, so the optimizer above still tracks them.
model.to(device)
### Use the following line only if training on multiple GPUs
# NOTE(review): wrapping is unconditional here. DataParallel prefixes every
# state_dict key with "module.", so checkpoints saved below must be loaded into
# a model wrapped the same way (the commented loading code runs after this line).
model = nn.DataParallel(model)

### Use following section only if continuing training from a checkpoint
# checkpoint = torch.load('/home/jay/data/no_blur/norm_lr0001_OneStep/norm_lr0001_Decay10continue2_e205.pt')
# model.load_state_dict(checkpoint['model_state_dict'])
# optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

# Loaded optimizer state tensors may be on the wrong device; move them to `device`.
# for state in optimizer.state.values():
#     for k,v in state.items():
#         if isinstance(v,torch.Tensor):
#             state[k] = v.to(device)


#### Change following line to model.eval() for inference
model.train()


Training Routine

In [None]:
# This cell continues an earlier run, so loop epoch 0 is reported and saved as
# epoch 51 (previously a magic `epoch+51` repeated in three places).
FIRST_EPOCH = 51
NUM_EPOCHS = 50
CHECKPOINT_EVERY = 5  # save a preview figure + checkpoint every N epochs

# Fail fast: `save_name` (output path prefix for figures/checkpoints) is not
# defined in this notebook; previously the NameError only surfaced after the
# first CHECKPOINT_EVERY epochs of training.
if 'save_name' not in globals():
    raise NameError("save_name must be set to an output path prefix before training")

n_batches = len(dataloader)
if n_batches == 0:
    raise ValueError("dataloader is empty; nothing to train on")
# Integer step for ~10 progress updates per epoch. The old float check
# `i % (len(dataloader)/10) == 0` misfired when the batch count was not a
# multiple of 10.
progress_every = max(1, n_batches // 10)

loss_history = []
start = time.time()
for epoch in range(NUM_EPOCHS):
    epoch_label = FIRST_EPOCH + epoch
    start_epoch = time.time()
    print('Epoch {}'.format(epoch_label))
    running_loss = 0.0
    # NOTE(review): `scheduler` is never stepped, so the LR stays constant. If
    # enabled, call scheduler.step() once per epoch after the optimizer steps.
    for i, (inputs, targets) in enumerate(dataloader):
        # Named `targets` (not `labels`) so the global numpy `labels` array is
        # not overwritten by the last training batch.
        inputs, targets = inputs.to(device), targets.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)

        # CrossEntropyLoss expects integer class indices as targets.
        loss = criterion(outputs, targets.long())
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if i % progress_every == 0:
            print('{}%'.format(100 * i // n_batches))
    epoch_loss = running_loss / n_batches
    print('loss: {}'.format(epoch_loss))
    print('Epoch time: {}\n'.format(time.time() - start_epoch))
    loss_history.append(epoch_loss)

    if (epoch + 1) % CHECKPOINT_EVERY == 0:
        # Preview the first sample of the last batch: channel-1 raw output and
        # channel-1 softmax probability overlaid on the input image.
        f, a = plt.subplots(ncols=2, figsize=(30, 20))
        im = inputs.detach().cpu().numpy()[0, 0, :, :]
        out_raw = outputs.detach().cpu().numpy()[0, 1, :, :]
        out = F.softmax(outputs, dim=1).detach().cpu().numpy()[0, 1, :, :]
        a[0].imshow(im, plt.cm.gray)
        a[0].imshow(out_raw, plt.cm.jet, alpha=0.3)
        a[0].set_title('Raw   Max={}'.format(np.max(out_raw)))
        a[1].imshow(im, plt.cm.gray)
        a[1].imshow(out, plt.cm.jet, alpha=0.3)
        a[1].set_title('Softmax  Max: {}'.format(out.max()))
        f.savefig(save_name + '_e{}.jpg'.format(epoch_label))
        # The figure is on disk; close it so large figures don't pile up in memory.
        plt.close(f)
        # NOTE(review): 'epoch' stores the 0-based loop index, not epoch_label
        # used in the file name.
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss
        }, save_name + '_e{}.pt'.format(epoch_label))
print('Finished Training')
print('Total Time: {}'.format(time.time() - start))


### Plot loss vs epoch
f, a = plt.subplots(figsize=(10, 10))
a.set_xlabel('Epoch')
a.set_ylabel('Loss')
a.plot(range(len(loss_history)), loss_history)
f.savefig(save_name + '_Loss.jpg')

Inference Routine

In [None]:
start = time.time()
# The training `dataloader` shuffles, so iterating it here stored predictions
# in an order that did not match `images`/`labels`. Use a fixed-order loader so
# predictions_soft[k] and raw[k] belong to dataset[k]. (This also replaces the
# undefined `batch_size` name the old indexing relied on.)
inference_loader = torch.utils.data.DataLoader(dataset, batch_size=8, shuffle=False)
height, width = dataset[0][0].shape[-2:]
# NOTE: no softmax is applied, so despite the name these are the model's raw
# 2-channel outputs; use F.softmax(p, dim=1) below if probabilities are wanted.
predictions_soft = np.zeros((len(dataset), 2, height, width))
raw = np.zeros((len(dataset), 1, height, width))  # input images, same order

# Evaluate in eval mode without autograd, then restore the previous mode.
was_training = model.training
model.eval()
offset = 0
with torch.no_grad():
    for i, (inputs, _) in enumerate(inference_loader):
        print(i)
        stop = offset + len(inputs)
        raw[offset:stop] = inputs.numpy()
        # `.to(device)` instead of `.cuda()` so this also runs on CPU.
        p = model(inputs.to(device))
        predictions_soft[offset:stop] = p.cpu().numpy()
        offset = stop
model.train(was_training)

print(time.time() - start)