In [4]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt
import torchvision.models as models

In [2]:
class AutoencoderRGB(nn.Module):
    """Convolutional autoencoder with a pretrained ResNet-50 encoder.

    Encoder: ResNet-50 up to and including ``layer4`` (the global average pool
    and the linear classifier are dropped so the latent keeps its spatial
    layout), followed by a 1x1 convolution to 4096 channels.

    Decoder: five stride-2 transposed convolutions (32x upsampling, matching
    ResNet-50's total downsampling stride of 32) ending in ``Tanh``. For an input
    of shape ``(N, 3, H, W)`` with ``H`` and ``W`` divisible by 32 (e.g.
    1024x1024) the output has shape ``(N, 3, H, W)`` with values in ``[-1, 1]``.
    """

    def __init__(self):
        super().__init__()
        self.resnet = models.resnet50(pretrained=True)
        # Drop ResNet's avgpool AND fc: keeping avgpool would collapse the feature
        # map to 1x1, and the decoder could then only ever produce a 32x32 image.
        self.encoder = nn.Sequential(
            *list(self.resnet.children())[:-2],
            nn.Conv2d(2048, 4096, kernel_size=1),
        )
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(4096, 2048, kernel_size=2, stride=2),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(2048, 1024, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(1024, 512, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(512, 256, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(64, 3, kernel_size=3, stride=1, padding=1),
            nn.Tanh(),
        )

    @staticmethod
    def _gray_to_rgb(x):
        """Turn a grayscale batch into a 3-channel NCHW float batch.

        Accepts ``(N, H, W)`` or ``(N, 1, H, W)``; non-tensor inputs (e.g. numpy
        arrays) are converted to float32 tensors.

        Raises:
            ValueError: if ``x`` is not a single-channel batch of those shapes.
        """
        if not torch.is_tensor(x):
            x = torch.as_tensor(x, dtype=torch.float32)
        if x.dim() == 3:
            # (N, H, W) -> (N, 1, H, W): PyTorch convs are channels-first.
            x = x.unsqueeze(1)
        if x.dim() != 4 or x.size(1) != 1:
            raise ValueError(
                "expected a grayscale batch of shape (N, H, W) or (N, 1, H, W), "
                "got {}".format(tuple(x.shape))
            )
        # The pretrained ResNet stem takes 3 input channels.
        return x.repeat(1, 3, 1, 1)

    def forwardRGB(self, x):
        """Encode and decode an RGB batch ``x`` of shape ``(N, 3, H, W)``."""
        x = self.encoder(x)
        x = self.decoder(x)
        return x

    def forwardGray(self, x):
        """Encode and decode a grayscale batch by feeding it as 3 identical channels.

        ``x`` may be ``(N, H, W)`` or ``(N, 1, H, W)``; the result is ``(N, 3, H', W')``.
        """
        return self.forwardRGB(self._gray_to_rgb(x))

    def forward(self, x):
        """Default ``nn.Module`` entry point so ``model(x)`` works; same as ``forwardRGB``."""
        return self.forwardRGB(x)

In [None]:
class AutoencoderGray(nn.Module):
    """Autoencoder for grayscale input built on a pretrained ResNet-50 encoder.

    Grayscale batches are replicated to 3 channels so the pretrained ResNet stem
    can be reused unchanged. The encoder is ResNet-50 up to and including
    ``layer4`` (avgpool and fc dropped, so the latent keeps its spatial layout)
    followed by a 1x1 convolution to 4096 channels. The decoder upsamples 32x
    back to the input resolution (for ``H``, ``W`` divisible by 32).

    NOTE(review): the decoder emits 3 channels, not 1 — confirm whether a gray
    reconstruction or a colorized RGB output is intended.
    """

    def __init__(self):
        super().__init__()
        self.resnet = models.resnet50(pretrained=True)
        # Drop ResNet's avgpool AND fc so the feature map is not collapsed to 1x1.
        self.encoder = nn.Sequential(
            *list(self.resnet.children())[:-2],
            nn.Conv2d(2048, 4096, kernel_size=1),
        )
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(4096, 2048, kernel_size=2, stride=2),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(2048, 1024, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(1024, 512, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(512, 256, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(64, 3, kernel_size=3, stride=1, padding=1),
            nn.Tanh(),
        )

    @staticmethod
    def _gray_to_rgb(x):
        """Turn a grayscale batch into a 3-channel NCHW float batch.

        Accepts ``(N, H, W)`` or ``(N, 1, H, W)``; non-tensor inputs (e.g. numpy
        arrays) are converted to float32 tensors.

        Raises:
            ValueError: if ``x`` is not a single-channel batch of those shapes.
        """
        if not torch.is_tensor(x):
            x = torch.as_tensor(x, dtype=torch.float32)
        if x.dim() == 3:
            # (N, H, W) -> (N, 1, H, W): PyTorch convs are channels-first.
            x = x.unsqueeze(1)
        if x.dim() != 4 or x.size(1) != 1:
            raise ValueError(
                "expected a grayscale batch of shape (N, H, W) or (N, 1, H, W), "
                "got {}".format(tuple(x.shape))
            )
        return x.repeat(1, 3, 1, 1)

    def forwardGray(self, x):
        """Encode and decode a grayscale batch ``x`` (``(N, H, W)`` or ``(N, 1, H, W)``)."""
        x = self.encoder(self._gray_to_rgb(x))
        x = self.decoder(x)
        return x

    def forward(self, x):
        """Default ``nn.Module`` entry point so ``model(x)`` works; same as ``forwardGray``."""
        return self.forwardGray(x)

In [3]:
def _run_epoch(model, forward_fn, loader, criterion, device, optimizer=None):
    """Run one pass over ``loader``: train when ``optimizer`` is given, else evaluate.

    Returns:
        ``(mean_loss, last_inputs, last_outputs)``. ``mean_loss`` is weighted by
        batch size and is NaN for an empty loader; the last batch's inputs and
        outputs are detached and moved to CPU, or ``None`` if no batch was seen.
    """
    training = optimizer is not None
    # train/eval mode matters here: the ResNet encoder contains BatchNorm layers.
    model.train(training)
    total, count = 0.0, 0
    last_inputs = last_outputs = None
    with torch.set_grad_enabled(training):
        for inputs, targets in loader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = forward_fn(inputs)
            loss = criterion(outputs, targets)
            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            batch = inputs.size(0)
            total += loss.item() * batch
            count += batch
            last_inputs, last_outputs = inputs.detach(), outputs.detach()
    if last_inputs is not None:
        last_inputs, last_outputs = last_inputs.cpu(), last_outputs.cpu()
    mean_loss = total / count if count else float("nan")
    return mean_loss, last_inputs, last_outputs


def trainRGB(model, train_data, valid_data, num_epochs=5, batch_size=64, learning_rate=1e-3, RGB=True, plot=True,
             criterion=None):
    """Train ``model`` with Adam and report per-epoch training/validation loss.

    Args:
        model: module exposing ``forwardRGB`` (used when ``RGB`` is true) and
            ``forwardGray`` (used otherwise), e.g. ``AutoencoderRGB``.
        train_data, valid_data: datasets yielding ``(inputs, targets)`` pairs.
            ``targets`` is compared directly against the model output, so it must
            match the output's shape — presumably the target image to reconstruct
            (TODO confirm against the dataset).
        num_epochs, batch_size, learning_rate: optimisation settings.
        RGB: choose ``model.forwardRGB`` (True) or ``model.forwardGray`` (False).
        plot: if true, plot training and validation loss per epoch.
        criterion: loss module; defaults to ``nn.MSELoss()`` (pixel-wise
            reconstruction loss, suited to the image-valued ``Tanh`` output).

    Returns:
        ``(train_outputs, valid_outputs)``: lists with one ``(epoch, inputs,
        outputs)`` tuple per epoch holding the last batch of that phase (detached,
        on CPU); ``inputs``/``outputs`` are ``None`` if the loader was empty.
    """
    torch.manual_seed(42)
    if criterion is None:
        criterion = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-5)
    forward_fn = model.forwardRGB if RGB else model.forwardGray
    # Feed batches to wherever the model's parameters live (CPU if it has none).
    device = next((p.device for p in model.parameters()), torch.device("cpu"))

    train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True)
    # Order does not affect the mean validation loss, so no shuffling is needed.
    valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=batch_size, shuffle=False)

    train_outputs, valid_outputs = [], []
    train_losses, val_losses = [], []
    for epoch in range(num_epochs):
        train_loss, train_in, train_out = _run_epoch(
            model, forward_fn, train_loader, criterion, device, optimizer)
        val_loss, val_in, val_out = _run_epoch(
            model, forward_fn, valid_loader, criterion, device)

        train_losses.append(train_loss)
        val_losses.append(val_loss)
        train_outputs.append((epoch, train_in, train_out))
        valid_outputs.append((epoch, val_in, val_out))
        print('Epoch:{}, Loss:{:.4f}, Val Loss:{:.4f}'.format(epoch + 1, train_loss, val_loss))

    if plot:
        epochs = range(1, num_epochs + 1)
        fig, ax = plt.subplots()
        ax.plot(epochs, train_losses, label="Training Loss")
        ax.plot(epochs, val_losses, label="Validation Loss")
        ax.set(title="Training/Validation loss", xlabel="Epoch", ylabel="Loss")
        ax.legend(loc='best')
        plt.show()

    return train_outputs, valid_outputs