In [1]:
import torch
torch.cuda.get_device_name(0), torch.cuda.device_count(), torch.cuda.is_available()

('GeForce GTX 1080 Ti', 1, True)

In [2]:
# setting device on GPU if available, else CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)
print()

#Additional Info when using cuda
if device.type == 'cuda':
    print(torch.cuda.get_device_name(0))
    print('Memory Usage:')
    print('Allocated:', round(torch.cuda.memory_allocated(0)/1024**3,1), 'GB')
    print('Cached:   ', round(torch.cuda.memory_cached(0)/1024**3,1), 'GB')

Using device: cuda

GeForce GTX 1080 Ti
Memory Usage:
Allocated: 0.0 GB
Cached:    0.0 GB


In [3]:
import os

from torch import nn
from torch.autograd import Variable
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from torchvision.datasets import MNIST
from torchvision.utils import save_image
from torchvision import transforms
from torchsummary import summary 

import numpy as np


In [4]:
class Data(Dataset):
    def __init__(self, X, transform):
        self.X = X
        self.transform = transform 
        
    def __getitem__(self, index):
        sample = self.X[index]
        # sample = torch.tensor(r)
        
        if self.transform:
            sample = self.transform(sample)
        
        return sample, sample
    
    def __len__(self):
        return len(self.X)

In [5]:
x_train = np.load('x_train.npy')
x_cv = np.load('x_cv.npy') 


img_transform = transforms.Compose([
    transforms.ToTensor(),
    # parameters mean, std are passed as 0.5, 0.5 in your case for each channe;
    # This will normalize the image in the range [-1,1]
    # image = (image - mean) / std 
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])


"""

def min_max_normalization(tensor, min_value, max_value):
    min_tensor = tensor.min()
    tensor = (tensor - min_tensor)
    max_tensor = tensor.max()
    tensor = tensor / max_tensor
    tensor = tensor * (max_value - min_value) + min_value
    return tensor


def tensor_round(tensor):
    return torch.round(tensor)

img_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(lambda tensor:min_max_normalization(tensor, 0, 1)),
    transforms.Lambda(lambda tensor:tensor_round(tensor))
])
"""

'\n\ndef min_max_normalization(tensor, min_value, max_value):\n    min_tensor = tensor.min()\n    tensor = (tensor - min_tensor)\n    max_tensor = tensor.max()\n    tensor = tensor / max_tensor\n    tensor = tensor * (max_value - min_value) + min_value\n    return tensor\n\n\ndef tensor_round(tensor):\n    return torch.round(tensor)\n\nimg_transform = transforms.Compose([\n    transforms.ToTensor(),\n    transforms.Lambda(lambda tensor:min_max_normalization(tensor, 0, 1)),\n    transforms.Lambda(lambda tensor:tensor_round(tensor))\n])\n'

In [6]:
batch_size = 248

In [7]:
x_processed_train = Data(x_train, img_transform)
x_processed_cv = Data(x_cv, img_transform)

X_train_data = DataLoader(x_processed_train, batch_size=batch_size, shuffle=True)
X_cv_data = DataLoader(x_processed_cv, batch_size=batch_size, shuffle=True)

In [8]:
"""
for data in X_train_data:
    # torch.Size([248, 64, 64, 3])
    print(data.shape)
"""

def to_img(x):
    x = x.view(x.size(0), 3, 64, 64)
    return x

def plot_sample_img(img, name):
    img = img.view(3, 64, 64)
    save_image(img, 'test/sample_{}.png'.format(name))
    

In [9]:
class EarlyStopping:
    """Early stops the training if validation loss doesn't improve after a given patience."""
    def __init__(self, patience=7, verbose=False):
        """
        Args:
            patience (int): How long to wait after last time validation loss improved.
                            Default: 7
            verbose (bool): If True, prints a message for each validation loss improvement. 
                            Default: False
        """
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.val_loss_min = np.Inf

    def __call__(self, val_loss, model):

        score = -val_loss

        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
        elif score < self.best_score:
            self.counter += 1
            print("EarlyStopping counter: {} out of {}".format(self.counter, self.patience))
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0

    def save_checkpoint(self, val_loss, model):
        '''Saves model when validation loss decrease.'''
        if self.verbose:
            print("Validation loss decreased ({:.6f} --> {:.6f}). Saving model ...".format(self.val_loss_min, val_loss))
        torch.save(model.state_dict(), 'checkpoint.pt')
        self.val_loss_min = val_loss

In [10]:
class vanilla_autoencoder(nn.Module):
    def __init__(self):
        super(vanilla_autoencoder, self).__init__()
        
        self.encoder = nn.Sequential(nn.BatchNorm1d(num_features=3 * 64 * 64,
                                                    track_running_stats=False),
                                     nn.Linear(in_features=3 * 64 * 64,
                                               out_features=4000,
                                               bias=True),
                                     nn.ELU(alpha=1.0,
                                            inplace=True),
                                     nn.Dropout(p=0.5,
                                                inplace=False),
                                     nn.BatchNorm1d(num_features=4000,
                                                    track_running_stats=False),
                                     nn.Linear(in_features=4000,
                                               out_features=1000,
                                               bias=True),
                                     nn.ELU(alpha=1.0,
                                            inplace=True),
                                     nn.Dropout(p=0.5,
                                                inplace=False),
                                     nn.BatchNorm1d(num_features=1000,
                                                    track_running_stats=False),
                                     nn.Linear(in_features=1000,
                                               out_features=300,
                                               bias=True),
                                     nn.ELU(alpha=1.0,
                                            inplace=True), 
                                     nn.BatchNorm1d(num_features=300,
                                                    track_running_stats=False),
                                     nn.Linear(in_features=300,
                                               out_features=150,
                                               bias=True),
                                     nn.Sigmoid())
        
        self.decoder = nn.Sequential(nn.BatchNorm1d(num_features=150,
                                                    track_running_stats=False),
                                     nn.Linear(in_features=150,
                                               out_features=300,
                                               bias=True),
                                     nn.ELU(alpha=1.0,
                                            inplace=True),
                                     nn.Dropout(p=0.5,
                                                inplace=False),
                                     nn.BatchNorm1d(num_features=300,
                                                    track_running_stats=False),
                                     nn.Linear(in_features=300,
                                               out_features=1000,
                                               bias=True),
                                     nn.ELU(alpha=1.0,
                                            inplace=True),
                                     nn.Dropout(p=0.5,
                                                inplace=False),
                                     nn.BatchNorm1d(num_features=1000,
                                                    track_running_stats=False),
                                     nn.Linear(in_features=1000,
                                               out_features=4000,
                                               bias=True),
                                     nn.ELU(alpha=1.0,
                                            inplace=True), 
                                     nn.BatchNorm1d(num_features=4000,
                                                    track_running_stats=False),
                                     nn.Linear(in_features=4000,
                                               out_features=3 * 64 * 64,
                                               bias=True),
                                     nn.Sigmoid())

    def forward(self,
                x):
        x = self.encoder(x)
        x = self.decoder(x)
        
        return x

In [11]:
model = vanilla_autoencoder().to(device)

In [29]:
import matplotlib.pyplot as plt

In [32]:
optimizer = torch.optim.Adam(params=model.parameters(),
                             weight_decay=0)
LRStep = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer=optimizer,
                                                    patience=5,
                                                    verbose=True,
                                                    mode="min")

n_epochs = 50
patience = 10

model.load_state_dict(torch.load('checkpoint.pt'))

# to track the training loss as the model trains
train_losses = []
# to track the validation loss as the model trains
valid_losses = []
# to track the average training loss per epoch as the model trains
avg_train_losses = []
# to track the average validation loss per epoch as the model trains
avg_valid_losses = [] 

# initialize the early_stopping object
early_stopping = EarlyStopping(patience=patience, verbose=True)

criterion = nn.BCELoss()

for epoch in range(1, n_epochs + 1):

    ###################
    # train the model #
    ###################
    model.train() # prep model for training
    for data, label in X_train_data:
    
        # forward pass: compute predicted outputs by passing inputs to the model
              
        data = data.view(data.size(0), -1)
        data = Variable(data).to(device)
        
        data = data.type(torch.cuda.FloatTensor)
        output = model.forward(data).to(device)
        target = data
                
        
        loss = criterion(output, target)
        
        
        # clear the gradients of all optimized variables
        optimizer.zero_grad()
        # backward pass: compute gradient of the loss with respect to model parameters
        loss.backward()
        # perform a single optimization step (parameter update)
        optimizer.step()
        # record training loss
        train_losses.append(loss.item())
    
    # print('epoch [{}/{}], train_loss:{:.4f}'.format(epoch + 1, n_epochs, mse_loss.data))

    ######################    
    # validate the model #
    ######################
    model.eval() # prep model for evaluation
    for data, label in X_cv_data:
        
        data = data.view(data.size(0), -1)
        data = Variable(data).to(device)
        
        data = data.type(torch.cuda.FloatTensor)
        output = model.forward(data).to(device)
        target = data
        
        loss = criterion(output, target)
        
        # record validation loss
        valid_losses.append(loss.item())

    # print training/validation statistics 
    # calculate average loss over an epoch
    train_loss = np.average(train_losses)
    valid_loss = np.average(valid_losses)
    avg_train_losses.append(train_loss)
    avg_valid_losses.append(valid_loss)

    
    print_msg = "epoch [{}/{}], train_bce_loss: {:.5f}, val_bce_loss {:.5f}".format(epoch + 1, n_epochs, train_loss, valid_loss)

    print(print_msg)

    # clear lists to track next epoch
    train_losses = []
    valid_losses = []
    BCE_losses = []
    
    if epoch % 5 == 0:
        plt.imsave('track.png', (Variable(output).data).cpu().numpy()[0].reshape(64, 64, 3))

    # early_stopping needs the validation loss to check if it has decresed, 
    # and if it has, it will make a checkpoint of the current model
    #early_stopping(valid_loss, model)

    if early_stopping.early_stop:
        print("Early stopping")
        break


epoch [2/50], train_bce_loss: -12.29821, val_bce_loss -11.90012
epoch [3/50], train_bce_loss: -12.28460, val_bce_loss -11.91571
epoch [4/50], train_bce_loss: -12.26610, val_bce_loss -11.75790
epoch [5/50], train_bce_loss: -12.24715, val_bce_loss -11.98671
epoch [6/50], train_bce_loss: -12.17251, val_bce_loss -11.63616
epoch [7/50], train_bce_loss: -12.28583, val_bce_loss -11.58521
epoch [8/50], train_bce_loss: -12.35167, val_bce_loss -11.95718
epoch [9/50], train_bce_loss: -12.30204, val_bce_loss -11.76891
epoch [10/50], train_bce_loss: -12.37299, val_bce_loss -11.87274
epoch [11/50], train_bce_loss: -12.44932, val_bce_loss -11.83587
epoch [12/50], train_bce_loss: -12.46879, val_bce_loss -11.64072
epoch [13/50], train_bce_loss: -12.46968, val_bce_loss -11.58517
epoch [14/50], train_bce_loss: -12.50143, val_bce_loss -11.28991
epoch [15/50], train_bce_loss: -12.48203, val_bce_loss -11.48395
epoch [16/50], train_bce_loss: -12.50637, val_bce_loss -11.30765
epoch [17/50], train_bce_loss: -1

KeyboardInterrupt: 

In [None]:

# load the last checkpoint with the best model
model.load_state_dict(torch.load('checkpoint.pt'))