In [None]:
import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import cv2
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import pdb
import math

# Release unoccupied GPU memory cached by PyTorch's allocator — presumably to
# reclaim memory left over from earlier runs in this kernel before training.
torch.cuda.empty_cache()

def get_random_crop(image, crop_height, crop_width):
    """Cut a random ``crop_height`` x ``crop_width`` window out of ``image``.

    The top-left corner is drawn uniformly with ``np.random`` from every
    position that keeps the window fully inside the image, including the
    bottom-/right-most position (so a crop the size of the image is allowed).

    Args:
        image: array indexed as ``image[row, col, ...]`` (height first).
        crop_height: number of rows in the crop.
        crop_width: number of columns in the crop.

    Returns:
        ``(crop, x, y)``: ``crop`` is a view into ``image`` and ``(x, y)`` is
        its top-left corner as (column, row).

    Raises:
        ValueError: if the crop is larger than the image in either dimension.
    """
    max_x = image.shape[1] - crop_width
    max_y = image.shape[0] - crop_height
    if max_x < 0 or max_y < 0:
        raise ValueError(
            f"crop ({crop_height}x{crop_width}) is larger than image "
            f"({image.shape[0]}x{image.shape[1]})"
        )

    # randint's upper bound is exclusive; +1 makes the last valid corner
    # reachable and lets a full-size crop (max == 0) work.
    x = np.random.randint(0, max_x + 1)
    y = np.random.randint(0, max_y + 1)

    crop = image[y: y + crop_height, x: x + crop_width]

    return crop, x, y

def searchForFocus(filename, substring):
    """Return the first integer that appears after ``substring`` in a text file.

    The text following the first occurrence of ``substring`` is split on
    whitespace. Commas are stripped from each token (so ``"1,234"`` parses as
    1234), and the first token that ``int()`` accepts is returned.

    Args:
        filename: path of the text file to read.
        substring: marker text to search for.

    Returns:
        The integer found, or ``None`` if ``substring`` does not occur in the
        file or no integer token follows it.
    """
    with open(filename, 'r') as file:
        data = file.read()

    location = data.find(substring)
    if location == -1:
        # find() returns -1 when absent; using it as an offset would start the
        # search at an arbitrary position (len(substring) - 1).
        return None

    for word in data[location + len(substring):].split():
        try:
            return int(word.replace(',', ''))
        except ValueError:
            continue
    return None

class Dataset(torch.utils.data.Dataset):
    """Paired before/after image samples stored one per subfolder on disk.

    Sample ``k`` is read from
    ``<foldername>/<subfolderPrefix><k>/before<k>.tif`` (network input) and
    ``.../after<k>.tif`` (label). Both images are read as grayscale, rescaled
    to [-1, 1] and cropped to the same randomly chosen ``cropSize`` window.

    NOTE: ``__getitem__`` treats its argument as the sample id itself (the
    number in the folder name), not as a position in ``ids``. ``ids`` only
    determines ``__len__``, so a loader that should visit exactly ``ids`` must
    supply those ids as indices (e.g. via a sampler).
    """

    # ids indicates what subfolders (samples) to access
    def __init__(self, foldername, subfolderPrefix, ids, cropSize=(640, 640)):
        """
        Args:
            foldername: root directory containing the sample subfolders.
            subfolderPrefix: subfolder name prefix, followed by the sample id.
            ids: sample ids belonging to this split.
            cropSize: (height, width) of the random crop; defaults to the
                previously hard-coded (640, 640).
        """
        self.foldername = foldername
        self.subfolderPrefix = subfolderPrefix
        self.ids = ids
        self.cropSize = cropSize

    def __len__(self):
        return len(self.ids)

    def _load_image(self, path):
        """Read ``path`` as grayscale and map pixel values [0, 255] -> [-1, 1]."""
        raw = cv2.imread(path, 0)
        if raw is None:
            # cv2.imread signals a missing/unreadable file by returning None.
            raise FileNotFoundError(f"could not read image: {path}")
        return raw * 1 / 255.0 * 2 - 1

    def __getitem__(self, index):
        """Return ``[input, label]`` for sample id ``index``.

        ``input`` is a (1, H, W) float64 tensor (explicit single channel) and
        ``label`` is an (H, W) float64 tensor; both cover the same crop window.
        """
        sampleFoldername = self.foldername + '/' + self.subfolderPrefix + str(index)
        cropHeight, cropWidth = self.cropSize

        before = self._load_image(sampleFoldername + '/before' + str(index) + '.tif')
        after = self._load_image(sampleFoldername + '/after' + str(index) + '.tif')

        # Randomly crop the input, then crop the label to the same region.
        before, cornerX, cornerY = get_random_crop(before, cropHeight, cropWidth)
        after = after[cornerY:cornerY + cropHeight, cornerX:cornerX + cropWidth]

        return [torch.from_numpy(before).unsqueeze(0), torch.from_numpy(after)]
    

In [None]:
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if use_cuda else "cpu")
torch.backends.cudnn.benchmark = True

DATA_DIR = '/home/aofeldman/Desktop/AFdataCollection'

# Randomly partition the full list into a training set and validation set
numSamples = 100 # total number of samples collected
frac = 1/5 # fraction to be validation
np.random.seed(0)
permutedIds = np.random.permutation(range(numSamples))
splitPoint = int((1-frac) * len(permutedIds))
trainingIds = permutedIds[:splitPoint]
valIds = permutedIds[splitPoint:]


def seedNumpyWorker(worker_id):
    """Give each DataLoader worker its own NumPy RNG stream.

    Worker processes inherit the parent's NumPy RNG state, so without this
    they can draw identical random-crop sequences. PyTorch already gives each
    worker a distinct torch seed; derive the NumPy seed from it.
    """
    np.random.seed(torch.initial_seed() % 2**32)


# Dataset.__getitem__ takes a sample id, not a position. Feed the training ids
# to the loader explicitly: plain shuffling would iterate positions 0..len-1,
# i.e. samples 0-79 regardless of the split, overlapping the validation set.
# (A sampler replaces shuffle=True; the two options are mutually exclusive.)
params = {'batch_size': 2,
          'num_workers': 2,
          'worker_init_fn': seedNumpyWorker}

training_set = Dataset(DATA_DIR, 'sample', trainingIds)
training_generator = torch.utils.data.DataLoader(
    training_set,
    sampler=torch.utils.data.SubsetRandomSampler(trainingIds.tolist()),
    **params)

validation_set = Dataset(DATA_DIR, 'sample', valIds)

In [None]:
# Load sample id 1 (input X, label Y) for a quick visual sanity check.
X, Y = training_set[1]

def imshow(img, wait, scale=0.15, windowName="Hi"):
    """Show a [-1, 1]-valued image tensor in an OpenCV window, then close it.

    Args:
        img: tensor with values in [-1, 1]; singleton (batch/channel)
            dimensions are squeezed away so a 2-D image remains.
        wait: milliseconds passed to ``cv2.waitKey`` before the window closes.
        scale: display size relative to the image (default 0.15, the previously
            hard-coded value).
        windowName: title of the OpenCV window.
    """
    # Undo the [-1, 1] normalisation -> [0, 1]; detach/cpu so tensors that
    # require grad or live on the GPU can be converted to numpy.
    img = img.detach().cpu() / 2 + 0.5     # unnormalize
    npimg = np.squeeze(img.numpy())
    # cv2.resize takes (width, height); clamp so small images never hit 0 px.
    width = max(1, int(scale * npimg.shape[1]))
    height = max(1, int(scale * npimg.shape[0]))
    cv2.imshow(windowName, cv2.resize(npimg, (width, height)))
    cv2.waitKey(wait)
    cv2.destroyAllWindows()
# Show the input then the label (10 s each), printing each tensor's shape.
for shown in (X, Y):
    imshow(shown, 10000)
    print(shown.shape)

In [None]:
# Baseline encoder/decoder (the `Net` class further down reassigns `net`).
# Spatial sizes for an H x W input: the encoder max-pools by 4, 2, 2
# (-> H/16 x W/16) and the decoder upsamples by 2, 2, 4 (-> H x W), so the
# output matches the label size when H and W are divisible by 16.
# Every (transposed) convolution uses stride 1 with an odd kernel k and
# padding (k - 1) // 2, which leaves the spatial size unchanged; channels
# stay at 4 until the final 4 -> 1 projection.
# Consider placing dropout layers after conv2d layers (conv2d -> batchnorm2d -> leakyReLU -> dropout(p=0.1))
# And also place after fully connected layers (linear -> leakyReLU -> dropout(p=0.3))

net = nn.Sequential(
    # Encoder section
    nn.Conv2d(1, 4, kernel_size=9, stride=1, padding=4),
    nn.BatchNorm2d(4),
    nn.LeakyReLU(negative_slope = 0.1, inplace=True),
    nn.MaxPool2d(kernel_size=4, stride=4),   # -> H/4 x W/4

    nn.Conv2d(4, 4, kernel_size=7, stride=1, padding=3),
    nn.BatchNorm2d(4),
    nn.LeakyReLU(negative_slope = 0.1, inplace=True),
    nn.MaxPool2d(kernel_size=2, stride=2),   # -> H/8 x W/8

    nn.Conv2d(4, 4, kernel_size=5, stride=1, padding=2),
    nn.BatchNorm2d(4),
    nn.LeakyReLU(negative_slope = 0.1, inplace=True),
    nn.MaxPool2d(kernel_size=2, stride=2),   # -> H/16 x W/16

    # Decoder section (mirrors the encoder)
    nn.ConvTranspose2d(4, 4, kernel_size=5, stride=1, padding=2),
    nn.BatchNorm2d(4),
    nn.LeakyReLU(negative_slope = 0.1, inplace=True),
    nn.Upsample(scale_factor = 2, mode='bilinear'),   # -> H/8 x W/8

    nn.ConvTranspose2d(4, 4, kernel_size=7, stride=1, padding=3),
    nn.BatchNorm2d(4),
    nn.LeakyReLU(negative_slope = 0.1, inplace=True),
    nn.Upsample(scale_factor = 2, mode='bilinear'),   # -> H/4 x W/4

    nn.ConvTranspose2d(4, 4, kernel_size=9, stride=1, padding=4),
    nn.BatchNorm2d(4),
    nn.LeakyReLU(negative_slope = 0.1, inplace=True),
    nn.Upsample(scale_factor = 4, mode='bilinear'),   # -> H x W

    nn.ConvTranspose2d(4, 1, kernel_size=3, stride=1, padding=1),
    nn.Tanh(),
)

net = net.to(device)

In [None]:
from collections import OrderedDict

class EncoderBlock(nn.Module):
    """Encoder stage: Conv2d -> [BatchNorm2d] -> LeakyReLU -> MaxPool2d.

    The convolution uses stride 1 and padding ``(kernel - 1) // 2`` (size
    preserving for odd kernels). The max-pool then divides H and W by
    ``poolSize``. BatchNorm is included only when ``use_norm`` is true.
    """

    def __init__(self, dimIn, dimOut, kernel, leakySlope, poolSize, use_norm):
        super(EncoderBlock, self).__init__()
        conv = nn.Conv2d(dimIn, dimOut, kernel_size=kernel, stride=1,
                         padding=(kernel - 1) // 2)
        norm = [nn.BatchNorm2d(dimOut)] if use_norm else []
        activation = nn.LeakyReLU(negative_slope=leakySlope, inplace=True)
        pool = nn.MaxPool2d(kernel_size=poolSize, stride=poolSize)
        self.block = nn.Sequential(conv, *norm, activation, pool)

    def forward(self, x):
        return self.block(x)
    
class DecoderBlock(nn.Module):
    """Decoder stage: ConvTranspose2d -> [BatchNorm2d] -> LeakyReLU -> bilinear Upsample.

    ``forward`` can first concatenate an extra tensor along the channel axis
    (dim 1), so ``dimIn`` must count those extra channels.
    """

    def __init__(self, dimIn, dimOut, kernel, leakySlope, scale, use_norm):
        super(DecoderBlock, self).__init__()
        layers = [nn.ConvTranspose2d(dimIn, dimOut, kernel_size=kernel, stride=1,
                                     padding=(kernel - 1) // 2)]
        if use_norm:
            layers.append(nn.BatchNorm2d(dimOut))
        layers.append(nn.LeakyReLU(negative_slope=leakySlope, inplace=True))
        layers.append(nn.Upsample(scale_factor=scale, mode='bilinear'))
        self.block = nn.Sequential(*layers)

    def forward(self, x, earlierX=None):
        # Optional skip input is stacked after x's channels.
        features = x if earlierX is None else torch.cat([x, earlierX], 1)
        return self.block(features)
        
class EndBlock(nn.Module):
    """Final stage: project to a single channel and squash with tanh.

    Uses a stride-1 ConvTranspose2d with padding ``(kernel - 1) // 2``
    (size preserving for odd kernels).
    """

    def __init__(self, dimIn, kernel):
        super(EndBlock, self).__init__()
        projection = nn.ConvTranspose2d(dimIn, 1, kernel_size=kernel,
                                        padding=(kernel - 1) // 2)
        self.block = nn.Sequential(projection, nn.Tanh())

    def forward(self, x):
        return self.block(x)
    
class Net(nn.Module):
    """Encoder/decoder network whose decoder stages also see a crop of the input.

    Stages: ``numEncoder`` EncoderBlocks (first 1 -> 16 channels with kernel 9,
    the rest 16 -> 16 with kernel 7, each max-pooling by 4), then
    ``numDecoder`` DecoderBlocks (kernel 7, upsample x4) and an EndBlock that
    maps 16 channels to one tanh output. Before each decoder stage, a random
    window of the network input is concatenated as one extra channel. The
    window has the same spatial size as the current feature map.
    """

    def __init__(self, numEncoder, numDecoder):
        super(Net, self).__init__()

        layers = [('e0', EncoderBlock(1, 16, 9, 0.1, 4, True))]
        for i in range(1, numEncoder):
            layers += [('e' + str(i), EncoderBlock(16, 16, 7, 0.1, 4, True))]

        # Each decoder receives one extra input channel: the crop of the input.
        added = 1
        for j in range(numDecoder):
            layers += [('d' + str(j), DecoderBlock(16 + added, 16, 7, 0.1, 4, True))]

        layers += [('f', EndBlock(16, 5))]

        print('layers', layers)
        # The Sequential registers the blocks so their parameters are visible to
        # net.parameters() / .to(device); forward() walks self.layers, which
        # holds the same module objects together with their names.
        self.model = nn.Sequential(*[module for _, module in layers])

        self.numEncoder = numEncoder
        self.numDecoder = numDecoder
        self.layers = layers

    def forward(self, x):
        """Run ``x`` through every stage and return the EndBlock output.

        Assumes ``x`` is (N, 1, H, W) — TODO confirm for other callers.
        """
        out = x
        for name, block in self.layers:
            if name[0] == 'd':
                # TODO: the crop position is random even in eval mode, so
                # predictions are not deterministic.
                out = block(out, self._randomInputCrop(x, out))
            else:
                out = block(out)
        return out

    @staticmethod
    def _randomInputCrop(x, features):
        """Return a random spatial window of ``x`` with ``features``' H and W."""
        height, width = x.size()[2:]
        cropHeight, cropWidth = features.size()[2:]
        max_x = width - cropWidth
        max_y = height - cropHeight
        if max_x < 0 or max_y < 0:
            raise ValueError(
                f"input {height}x{width} is smaller than decoder features "
                f"{cropHeight}x{cropWidth}"
            )
        # randint excludes its upper bound; +1 keeps the equal-size case valid.
        cornerX = np.random.randint(0, max_x + 1)
        cornerY = np.random.randint(0, max_y + 1)
        return x[:, :, cornerY:cornerY + cropHeight, cornerX:cornerX + cropWidth]
        
# Build the 3-encoder / 3-decoder model on the selected device; this rebinds
# `net`, replacing the nn.Sequential baseline from the previous cell.
net = Net(numEncoder=3, numDecoder=3).to(device)

In [None]:
# Inspect the weights of the first convolution (encoder stage 'e0').
firstConv = net.layers[0][1].block[0]
firstConv.weight

In [None]:
# List the shape of every trainable parameter tensor.
for param in net.parameters():
    print(param.shape)

In [None]:
# Count trainable parameters, printing each parameter tensor along the way.
count = 0
for param in net.parameters():
    count += np.prod(param.shape).item()
    print(param.data)
print(f'total params: {count}')

In [None]:
# Pixel-wise mean-squared error between prediction and label.
criterion = nn.MSELoss()
# RMSprop with default hyper-parameters; SGD (lr=0.001, momentum=0.9) was the
# earlier alternative. `optim` is imported in the first cell.
optimizer = optim.RMSprop(net.parameters())

In [None]:
# Train with gradient accumulation: gradients of `batch_multiplier` consecutive
# mini-batches are summed before each optimizer step, emulating a batch that
# is `batch_multiplier` times larger. Every `learnFreq` epochs (and after the
# last epoch) the RMSE over the full training and validation sets is printed
# and the MSE is appended to epoch_training_loss / epoch_val_loss.
#
# TODO (early stopping idea): use blocks of epochs (say size 3) and a window
# (say size 10) of how far to look back; halt once the mean validation loss of
# the latest block is not sufficiently smaller than that of the earlier block.

max_epochs = 100
learnFreq = 10
batch_multiplier = 4

epoch_training_loss = []
epoch_val_loss = []


def evaluateMSE(dataset):
    """Mean per-pixel squared error of `net` over every sample id in `dataset.ids`."""
    totalMSE = 0.0
    for sample in dataset.ids:
        X, y = dataset[sample]
        X = X.unsqueeze(0).to(device).float()  # add fake batch dimension
        yHat = net(X).to('cpu')
        # y.numel() instead of np.product(...): np.product was removed in NumPy 2.0.
        totalMSE += (torch.norm(y - yHat) ** 2 / y.numel()).item()
    return totalMSE / len(dataset.ids)


for epoch in range(max_epochs):
    print("\nOn epoch: " + str(epoch))

    net.train()
    optimizer.zero_grad()
    pending = 0  # mini-batches whose gradients have not been applied yet
    for inputs, labels in training_generator:
        inputs, labels = inputs.to(device).float(), labels.to(device).float()

        outputs = net(inputs)
        # squeeze(1) drops only the channel axis; a bare squeeze() would also
        # drop the batch axis of a size-1 batch and mis-broadcast vs labels.
        loss = criterion(outputs.squeeze(1), labels) / batch_multiplier
        loss.backward()
        pending += 1

        if pending == batch_multiplier:
            optimizer.step()
            optimizer.zero_grad()
            pending = 0

    # Apply a trailing partial group now, rather than after the evaluation
    # below (or never, after the final epoch).
    if pending:
        optimizer.step()
        optimizer.zero_grad()

    if epoch % learnFreq == 0 or epoch == (max_epochs - 1):
        net.eval()
        with torch.no_grad():
            for splitName, dataset, history in (('training', training_set, epoch_training_loss),
                                                ('validation', validation_set, epoch_val_loss)):
                print('\nEpoch ' + splitName + ' results ')
                MSE = evaluateMSE(dataset)
                print('RMSE on dataset: ' + str(math.sqrt(MSE)))
                history.append(MSE)

print('Finished Training')

In [None]:
# Visual check: for every training and validation sample show the input, the
# label and the prediction (10 s each), then report each split's RMSE.
with torch.no_grad():
    net.eval()
    for i, dataset in enumerate([training_set, validation_set]):
        if i == 0:
            print('\nTraining results ')
        else:
            print('\nValidation results ')
        MSE = 0
        for sample in dataset.ids:
            X, y = dataset.__getitem__(sample)

            X = X.to(device).unsqueeze(0) # Add fake batch dimension
            yHat = net(X.float()).to('cpu')
            imshow(X.to('cpu'), 10000)
            imshow(y, 10000)
            imshow(yHat, 10000)
            print(yHat)
            # Squared Frobenius norm of the error, averaged per pixel
            # (y.numel() replaces np.product, which NumPy 2.0 removed).
            MSE += torch.norm(y - yHat)**2 / y.numel()
        # Average over *this* split; previously both splits were divided by the
        # validation-set size (inflating the training MSE 4x with an 80/20 split).
        MSE /= len(dataset.ids)
        # Printed per split; previously only the last split's RMSE was shown.
        print('RMSE on dataset: ' + str(math.sqrt(float(MSE))))

In [None]:
%matplotlib qt5
plt.figure()
plt.plot(range(max_epochs), epoch_training_loss, range(max_epochs), epoch_val_loss)
plt.legend(['Training', 'Validation'])