In [1]:
import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import cv2
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import pdb
import math

torch.cuda.empty_cache()

def get_random_crop(image, crop_height, crop_width):

    max_x = image.shape[1] - crop_width
    max_y = image.shape[0] - crop_height

    x = np.random.randint(0, max_x)
    y = np.random.randint(0, max_y)

    crop = image[y: y + crop_height, x: x + crop_width]

    return crop, x, y

def searchForFocus(filename, substring):
    with open(filename, 'r') as file:
        data = file.read()
        location = data.find(substring)
        croppedStr = data[location+len(substring):]
        # Split at spaces and find first number
        for word in croppedStr.split(): # Split at spaces
            # Delete any commas    
            word = word.replace(',', "")
            try:
                focusPosition = int(word)
                return focusPosition
            except ValueError:
                continue
    file.close()

class Dataset(torch.utils.data.Dataset):
    # ids indicates what subfolders (samples) to access
    def __init__(self, foldername, subfolderPrefix, ids):
        self.foldername = foldername
        self.subfolderPrefix = subfolderPrefix
        self.ids = ids
        
    def __len__(self):
        return len(self.ids)

    def __getitem__(self, index):
        sampleFoldername = self.foldername + '/' + self.subfolderPrefix + str(index)
        
        # H, W
        cropSize = (640, 640)
        
        images = []
        for i, prefix in enumerate(['before', 'after']):
        
            # Load in image as [0,1] array
            image = cv2.imread(sampleFoldername + '/' + prefix + str(index) + '.tif', 0) * 1 / 255.0

            # Shift it so is from [-1,1]
            image *= 2
            image -= 1
            
            if i == 0:
                # Randomly crop the image
                image, cornerX, cornerY = get_random_crop(image, cropSize[0], cropSize[1])
            else:
                # Crop the label image to the same region as the input
                image = image[cornerY:cornerY + cropSize[0], cornerX:cornerX + cropSize[1]]
            
            temp = torch.from_numpy(image)
            if i == 0:
                temp = temp.unsqueeze(0) # Add fake first dimension to specify 1-channel
            images.append(temp)
            
        return images
    

In [2]:
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if use_cuda else "cpu")
torch.backends.cudnn.benchmark = True

params = {'batch_size': 2,
          'shuffle': True,
          'num_workers': 2}

# Randomly partition the full list into a training set and validation set
numSamples = 100 # total number of samples collected
frac = 1/5 # fraction to be validation
np.random.seed(0)
permutedIds = np.random.permutation(range(numSamples))
splitPoint = int((1-frac) * len(permutedIds))
trainingIds = permutedIds[:splitPoint]
valIds = permutedIds[splitPoint:]

training_set = Dataset('/home/aofeldman/Desktop/AFdataCollection', 'sample', trainingIds)
training_generator = torch.utils.data.DataLoader(training_set, **params)

validation_set = Dataset('/home/aofeldman/Desktop/AFdataCollection', 'sample', valIds)
#validation_generator = torch.utils.data.DataLoader(validation_set, **params)

In [3]:
X, Y = training_set.__getitem__(1)

def imshow(img,wait):
    img = img / 2 + 0.5     # unnormalize
    npimg = np.squeeze(img.numpy())
    width = int(0.15 * npimg.shape[1])
    height = int(0.15 * npimg.shape[0])
    cv2.imshow("Hi",cv2.resize(npimg, (width, height)))
    cv2.waitKey(wait)
    cv2.destroyAllWindows()
#imshow(X, 10000)
#print(X.shape)
#imshow(Y, 10000)
#print(Y.shape)

In [4]:
class EncoderBlock(nn.Module):
    def __init__(self, dimIn, dimMid, dimOut, kernel, leakySlope):
        super(EncoderBlock, self).__init__()
        block = [nn.Conv2d(dimIn, dimMid, kernel_size=1, stride=1, padding = 0),
                 nn.BatchNorm2d(dimMid), nn.LeakyReLU(leakySlope)]
        block += [nn.Conv2d(dimMid, dimMid, kernel_size=kernel, stride=1, padding= (kernel-1) // 2),
                  nn.BatchNorm2d(dimMid), nn.LeakyReLU(leakySlope)]
        block += [nn.Conv2d(dimMid, dimOut, kernel_size=1, stride=1, padding=0),
                  nn.BatchNorm2d(dimOut), nn.LeakyReLU(0)]
        
        self.block = nn.Sequential(*block)
                
    def forward(self, x, earlierX = None):
        #print('Called Decoder forward')
        if earlierX is not None:
            #print('Shape of x: ', x.size())
            #print('Shape of earlierX: ', earlierX.size())
            combinedChannels = torch.cat([x, earlierX], 1)
            #print('Shape of combinedChannels: ', combinedChannels.size())
            return self.block(combinedChannels)
        else:
            return self.block(x)
    
class DecoderBlock(nn.Module):
    def __init__(self, dimIn, dimMid, dimOut, kernel, leakySlope, numLayers):
        super(DecoderBlock, self).__init__()
        block = []
        for i in range(numLayers):
            if i == 0:
                dim1 = dimIn
                dim2 = dimMid
            elif i == numLayers - 1:
                dim1 = dimMid
                dim2 = dimOut
            else:
                dim1 = dimMid
                dim2 = dimMid
            block += [nn.Conv2d(dim1, dim2, kernel_size=kernel, stride=1, padding=(kernel-1) // 2),
                 nn.BatchNorm2d(dim2), nn.LeakyReLU(leakySlope)]
        
        self.block = nn.Sequential(*block)
            
    def forward(self, x, earlierX = None):
        return self.block(x)

class StartBlock(nn.Module):
    def __init__(self, dimOut, kernel):
        super(StartBlock, self).__init__()
        
        self.block = nn.Sequential(nn.Conv2d(1, dimOut, kernel_size=kernel, padding = (kernel-1) // 2),
                                   nn.ReLU())
        
    def forward(self, x):
        return self.block(x)

class EndBlock(nn.Module):
    def __init__(self, dimIn, kernel):
        super(EndBlock, self).__init__()
        
        self.block = \
        nn.Sequential(nn.Conv2d(dimIn, 1, kernel_size=kernel, padding= (kernel-1) // 2), nn.ReLU())
        
    def forward(self, x):
        if earlierX is not None:
            x += earlierX
        return 2 * torch.clamp(1 - self.block(x), 0) - 1
            
class Net(nn.Module):
    def __init__(self, numEncoder, numDecoder):
        super(Net, self).__init__()
        
        layers = [('s', StartBlock(32, 3)), 
                  ('e0', EncoderBlock(32, 64, 32, 3, 0.1)), 
                  ('e1', EncoderBlock(64, 64, 32, 3, 0.1)), 
                  ('e2', EncoderBlock(64, 128, 32, 3, 0.1)), 
                  ('e3', EncoderBlock(64, 128, 32, 3, 0.1)), 
                  ('e4', EncoderBlock(64, 256, 32, 3, 0.1)), 
                  ('e5', EncoderBlock(64, 256, 32, 3, 0.1)), 
                  ('d0', DecoderBlock(32, 32, 16, 3, 0.1, 3)),
                  ('f', EndBlock(16, 3))]
        
        print('Layers: ', layers)
        self.model = nn.Sequential(*[layers[i][1] for i in range(len(layers))])
        
        self.numEncoder = numEncoder
        self.numDecoder = numDecoder
        self.layers = layers
                
    def forward(self, x):
        layerOutputs = {}

        for i, (name, group) in enumerate(self.layers):
            #print('On layer: ' + block[0])
            if name[0] == 'e' and name[1] != '0':
                layerOutputs[name] = group.forward(x, layerOutputs[self.layers[i-2][0]])
            else:
                layerOutputs[name] = group.forward(x)
            x = layerOutputs[name]
            
        return layerOutputs['f']
        
net = Net(3, 3).to(device)

Layers:  [('s', StartBlock(
  (block): Sequential(
    (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
  )
)), ('e0', EncoderBlock(
  (block): Sequential(
    (0): Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): LeakyReLU(negative_slope=0.1)
    (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): LeakyReLU(negative_slope=0.1)
    (6): Conv2d(64, 32, kernel_size=(1, 1), stride=(1, 1))
    (7): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): LeakyReLU(negative_slope=0)
  )
)), ('e1', EncoderBlock(
  (block): Sequential(
    (0): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): LeakyReLU(negative_slop

In [5]:
for p in net.parameters():
    print(p.data.shape)

torch.Size([32, 1, 3, 3])
torch.Size([32])
torch.Size([64, 32, 1, 1])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64, 64, 3, 3])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([32, 64, 1, 1])
torch.Size([32])
torch.Size([32])
torch.Size([32])
torch.Size([64, 64, 1, 1])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64, 64, 3, 3])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([32, 64, 1, 1])
torch.Size([32])
torch.Size([32])
torch.Size([32])
torch.Size([128, 64, 1, 1])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128, 128, 3, 3])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([32, 128, 1, 1])
torch.Size([32])
torch.Size([32])
torch.Size([32])
torch.Size([128, 64, 1, 1])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128, 128, 3, 3])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([32, 128, 1, 1])
torch.Size([32])
torch.Size([32])
torch.Size([32])
t

In [6]:
count = 0
for p in net.parameters():
    n_params = np.prod(list(p.data.shape)).item()
    count += n_params
    print(p.data)
print(f'total params: {count}')

tensor([[[[-0.0608,  0.1452, -0.2890],
          [-0.3223, -0.3267,  0.1308],
          [ 0.0603, -0.0154, -0.1141]]],


        [[[-0.2957,  0.1331,  0.2162],
          [ 0.1825, -0.3309,  0.0498],
          [ 0.0389,  0.3044,  0.1900]]],


        [[[ 0.2743, -0.1561, -0.1370],
          [ 0.2853,  0.2680,  0.0871],
          [ 0.0132,  0.2986,  0.0353]]],


        [[[ 0.3192,  0.1127, -0.0601],
          [-0.0219, -0.2803,  0.3321],
          [ 0.0710, -0.1584, -0.0685]]],


        [[[ 0.2697,  0.2063,  0.1326],
          [ 0.2347,  0.0794,  0.2464],
          [ 0.0056, -0.1145, -0.2902]]],


        [[[ 0.2857, -0.1263,  0.0675],
          [ 0.1497,  0.3255,  0.1640],
          [-0.2548,  0.1705,  0.0171]]],


        [[[-0.0086, -0.3112, -0.3126],
          [-0.2189,  0.0292,  0.3244],
          [-0.0239,  0.2021,  0.1512]]],


        [[[-0.2612, -0.0627, -0.2734],
          [ 0.1482, -0.2153, -0.2464],
          [-0.0333, -0.0575,  0.2519]]],


        [[[ 0.2571,  0.1906, -0.

tensor([[[[-0.1033]],

         [[ 0.0577]],

         [[-0.0413]],

         ...,

         [[ 0.0915]],

         [[-0.0731]],

         [[ 0.1006]]],


        [[[ 0.0224]],

         [[ 0.0468]],

         [[ 0.0844]],

         ...,

         [[-0.0489]],

         [[-0.1228]],

         [[ 0.0596]]],


        [[[ 0.0518]],

         [[ 0.0444]],

         [[ 0.0343]],

         ...,

         [[-0.1003]],

         [[-0.1137]],

         [[-0.0320]]],


        ...,


        [[[-0.1140]],

         [[-0.0509]],

         [[ 0.0214]],

         ...,

         [[-0.0181]],

         [[ 0.0971]],

         [[-0.1032]]],


        [[[-0.0053]],

         [[-0.0555]],

         [[ 0.0455]],

         ...,

         [[ 0.0908]],

         [[-0.1220]],

         [[-0.0464]]],


        [[[ 0.0244]],

         [[-0.0743]],

         [[ 0.0472]],

         ...,

         [[-0.0164]],

         [[-0.1162]],

         [[ 0.0738]]]], device='cuda:0')
tensor([ 0.0902,  0.0010, -0.0343, -0.1

tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 

tensor([-0.0111, -0.0062, -0.0201,  0.0149,  0.0019,  0.0185,  0.0074,  0.0033,
         0.0080, -0.0142, -0.0060, -0.0039,  0.0120,  0.0078, -0.0166, -0.0203,
         0.0148,  0.0175, -0.0054,  0.0174,  0.0171,  0.0186,  0.0203, -0.0198,
         0.0050,  0.0075,  0.0082, -0.0019,  0.0028, -0.0100, -0.0208,  0.0036,
         0.0178, -0.0157,  0.0063, -0.0161, -0.0110,  0.0017,  0.0064,  0.0092,
        -0.0204,  0.0192,  0.0013, -0.0115, -0.0030, -0.0139, -0.0108,  0.0107,
        -0.0141, -0.0170, -0.0076,  0.0131,  0.0206,  0.0049,  0.0193,  0.0181,
        -0.0184,  0.0067,  0.0112,  0.0107, -0.0162,  0.0130,  0.0157,  0.0163,
        -0.0024, -0.0174,  0.0165,  0.0111, -0.0130,  0.0009,  0.0065,  0.0085,
         0.0162,  0.0087, -0.0163, -0.0015,  0.0056,  0.0021,  0.0119,  0.0091,
         0.0142,  0.0040, -0.0109, -0.0156,  0.0040,  0.0009, -0.0144,  0.0206,
         0.0208,  0.0111,  0.0200, -0.0112,  0.0136,  0.0114, -0.0146, -0.0044,
         0.0097,  0.0191, -0.0144,  0.00

In [7]:
import torch.optim as optim

criterion = nn.MSELoss()
#optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
optimizer = optim.RMSprop(net.parameters())

In [8]:
# For early stopping, use blocks of epochs (say size 3), and a window (say size 10) of how far to look in the past.
# Then, compare the average loss on the current block to the block in the past. Another idea would be compare the variances.
# If the fraction is not sufficiently small, halt.

max_epochs = 100
# Should be stated as a fraction of the previous error
#max_frac = 0.999
#window = 10 # How many epochs in the past to compare to
#block = 5 # What size block of epochs to use

learnFreq = 10
batch_multiplier = 4

epoch_training_loss = []
epoch_val_loss = []


for epoch in range(max_epochs):
    print("\nOn epoch: " + str(epoch))

    count = 0
    
    net.train()
    for inputs, labels in training_generator:
        
        inputs, labels = inputs.to(device).float(), labels.to(device).float()
        
        if count == 0:
            optimizer.step()
            # zero the parameter gradients
            optimizer.zero_grad()
            count = batch_multiplier

        # forward + backward + optimize
        outputs = net(inputs).to(device)
        loss = criterion(torch.squeeze(outputs), labels) / batch_multiplier
        loss.backward()
        #optimizer.step()
        
        count -=1
        
        loss = loss.detach()
        inputs = inputs.detach()
        outputs = outputs.detach()
        
        # print('Batch loss: ', loss.item())

        #print('Outputs', torch.squeeze(outputs))
        #print('Batch size: ' + str(len(inputs)))
        # Multiply by the batch size and batch_multiplier (because earlier divided)
        #running_loss += loss.item() * len(inputs) * batch_multiplier
        #print('Batch average loss ' + str(loss.item()))
    #training_loss = running_loss / training_set.__len__()
    #print("Epoch training loss: " + str(training_loss))
    #epoch_training_loss.append(training_loss)
    if epoch % learnFreq == 0 or epoch == (max_epochs - 1):
        with torch.no_grad():
            net.eval()
            for i, dataset in enumerate([training_set, validation_set]):
                if i == 0:
                    print('\nEpoch training results ')
                else:
                    print('\nEpoch validation results ')
                MSE = 0
                avgAbsDev = 0
                for sample in dataset.ids:
                    X, y = dataset.__getitem__(sample)
                    X = X.unsqueeze(0) # Add fake batch dimension
                    X = X.to(device).float()
        
                    yHat = net(X).to(device)
                    yHat = yHat.to('cpu')
                    MSE += torch.norm(y - yHat)**2 / np.product(y.numpy().shape)
                    #avgAbsDev += np.abs(yHat - y)
        
                #print('Sample: ' + str(sample))
                #print('y: ' + str(y))
                #print('yHat: ' + str(yHat))
                MSE /= len(dataset.ids)
                #avgAbsDev /= len(dataset.ids)
                print('RMSE on dataset: ' + str(np.sqrt(MSE)))
                #print('Avg Abs Dev on dataset: ' + str(avgAbsDev))
                if i == 0:
                    epoch_training_loss.append(MSE)
                else:
                    epoch_val_loss.append(MSE)
        
# for epoch in range(max_epochs):
#     print("\nOn epoch: " + str(epoch))
        
#     net.train()
#     for inputs, labels in training_generator:
#         # zero the parameter gradients
#         optimizer.zero_grad()
#         print('inputs.shape: ', inputs.size())
        
#         # forward + backward + optimize
#         outputs = net(inputs.float())
#         print('outputs.shape: ', outputs.size())
#         print('labels.shape: ', labels.float().size())
#         loss = criterion(torch.squeeze(outputs), labels.float())
#         loss.backward()
#         optimizer.step()
#         print('Batch Loss: ' + str(loss.item()))
        
#     with torch.no_grad():
#         val_loss = 0
#         net.eval()
#         for ind in validation_set.ids:
#             X, y = validation_set.__getitem__(ind)
#             X = X.unsqueeze(0) # Add fake batch dimension
#             yHat = net(X.float())
#             val_loss += torch.norm(y - yHat)**2 / np.product(y.numpy().shape)
#         val_loss /= validation_set.__len__()
#         epoch_val_loss.append(val_loss)
#     print("Epoch validation loss: " + str(val_loss))
        
#     if len(epoch_val_loss) >= window + block + 1:
#         latestBlock = np.mean(epoch_val_loss[-1:-1-block])
#         earlierBlock = np.mean(epoch_val_loss[-1-window:-1-window-block])
        
#         # latestBlock must be sufficiently smaller than earlierBlock
#         if latestBlock / earlierBlock > max_frac:
#             print('Converged')
#             pdb.set_trace()
#             break
            
print('Finished Training')


On epoch: 0


RuntimeError: CUDA out of memory. Tried to allocate 200.00 MiB (GPU 0; 3.95 GiB total capacity; 2.84 GiB already allocated; 197.69 MiB free; 3.04 GiB reserved in total by PyTorch)

In [None]:
with torch.no_grad():
    net.eval()
    for i, dataset in enumerate([training_set, validation_set]):
        if i == 0:
            print('\nTraining results ')
        else:
            print('\nValidation results ')
        MSE = 0
        for sample in dataset.ids:
            X, y = dataset.__getitem__(sample)
            
            X = X.to(device).unsqueeze(0) # Add fake batch dimension
            yHat = net(X.float()).to(device)
            yHat = yHat.to('cpu')
            imshow(X.to('cpu'), 10000)
            imshow(y, 10000)
            imshow(yHat, 10000)
            print(yHat)
            # squared frobenius norm
            MSE += torch.norm(y - yHat)**2 / np.product(y.numpy().shape)
        MSE /= validation_set.__len__()
    print('RMSE on dataset: ' + str(np.sqrt(MSE)))

In [None]:
%matplotlib qt5
plt.figure()
plt.plot(range(max_epochs), epoch_training_loss, range(max_epochs), epoch_val_loss)
plt.legend(['Training', 'Validation'])