In [3]:
import os

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.utils.data as data
import torchvision.transforms.functional as TF
from torchvision import transforms

"""
Here we explore data augmentation. TODO: finish writing this experiment and run it.
"""

class AE_small5(torch.nn.Module):
    """Small convolutional encoder-decoder with concatenated skip connections.

    Encoder: five 3x3 convolutions with padding 1 (3 -> 16 -> 24 -> 48 -> 48
    -> 48 channels).  The first four are each followed by a 2x2 max-pool, so
    the spatial size shrinks by a factor of 16 before ``conv5``.

    Decoder: at each level the current features are concatenated (dim 1)
    with the matching encoder features, upsampled by 2 and passed through a
    3x3 convolution.  The last level concatenates with the raw input and maps
    back to 3 channels, and a final ``nn.Linear(32, 32)`` is applied along
    the last (width) axis.  Inputs must therefore have 3 channels, width 32,
    and a height divisible by 16 so that the skip tensors line up.
    """

    def __init__(self):
        super().__init__()
        self.pool = nn.MaxPool2d(kernel_size=2)

        # Encoder convolutions (padding=1 keeps the spatial size unchanged).
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(16, 24, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(24, 48, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(48, 48, kernel_size=3, padding=1)
        self.conv5 = nn.Conv2d(48, 48, kernel_size=3, padding=1)

        # Decoder convolutions: in_channels = previous output + skip channels.
        self.deconv1 = nn.Conv2d(48 + 48, 48, kernel_size=3, padding=1)  # conv5 + conv4
        self.deconv2 = nn.Conv2d(48 + 48, 48, kernel_size=3, padding=1)  # deconv1 + conv3
        self.deconv3 = nn.Conv2d(48 + 24, 24, kernel_size=3, padding=1)  # deconv2 + conv2
        self.deconv4 = nn.Conv2d(24 + 16, 16, kernel_size=3, padding=1)  # deconv3 + conv1
        self.deconv5 = nn.Conv2d(16 + 3, 3, kernel_size=3, padding=1)  # deconv4 + input

        self.l_relu = nn.LeakyReLU(negative_slope=0.1)
        self.upsample = nn.Upsample(scale_factor=(2, 2))
        # Acts on the last (width) axis of the 4-D output, hence width 32.
        self.linear = nn.Linear(32, 32)

    def forward(self, x):
        act = self.l_relu

        # Encoder: remember each pooled activation for the skip connections.
        skips = []
        feat = x
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
            feat = act(self.pool(act(conv(feat))))
            skips.append(feat)
        feat = act(self.conv5(feat))

        # Decoder: consume the skips deepest-first (conv4, conv3, conv2, conv1).
        for deconv in (self.deconv1, self.deconv2, self.deconv3, self.deconv4):
            merged = torch.cat((feat, skips.pop()), dim=1)
            feat = act(deconv(act(self.upsample(merged))))

        # Final level: skip straight from the network input, no activation.
        return self.linear(self.deconv5(torch.cat((feat, x), dim=1)))

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

class Dataset(torch.utils.data.Dataset):
    """Characterizes a dataset for PyTorch: pairs of noisy images ``(X, Y)``.

    Args:
        SIZE: number of leading samples to keep; capped at 50000 for the
            training split and 1000 for the validation split.
        train: if True load ``train_data.pkl``, otherwise ``val_data.pkl``.
            Each file must unpack into two tensors ``(x, y)``; both are
            converted to float.
        transform: optional callable applied to both X and Y.  The global
            torch RNG is re-seeded with the same value before each of the two
            calls, so random transforms that draw from torch's global RNG give
            X and Y the same augmentation.
        switch_pixels: optional ``(n_max, p)``.  With probability ``p``, a
            random number of pixel positions drawn from ``[50, n_max)`` (at
            least 50) is swapped between X and Y across all channels.
    """

    def __init__(self, SIZE, train=True, transform=None, switch_pixels=None):
        # NOTE(review): torch.load unpickles the file; only load trusted data.
        if train:
            if SIZE > 50000:
                print("SIZE is to big. It is set to SIZE = 50000.")
                SIZE = 50000
            x, y = torch.load("train_data.pkl")
            print("Training data : \n noisy_imgs_1 : ", x.shape, "\n noisy_imgs_2 : ", y.shape)
        else:
            if SIZE > 1000:
                print("SIZE is to big. It is set to SIZE = 1000.")
                SIZE = 1000
            x, y = torch.load("val_data.pkl")
            print("Test data : \n noisy_imgs : ", x.shape, "\n clean_imgs : ", y.shape)
        x, y = x[:SIZE], y[:SIZE]
        print("Data reduced : \n noisy_imgs_1_reduced : ", x.shape, "\n noisy_imgs_2_reduced : ", y.shape)
        print("Type : ", x.dtype)
        if transform is not None:
            print("With data augmentation : transform.")
        if switch_pixels is not None:
            print("With data augmentation : switch pixels with n_max = ", switch_pixels[0], " and p = ", switch_pixels[1])
        self.x = x.float()
        self.y = y.float()
        self.transform = transform
        self.switch_pixels = switch_pixels

    def __len__(self):
        """Denotes the total number of samples."""
        return len(self.x)

    def __getitem__(self, index):
        """Generates one sample of data as ``(X, Y)``."""
        X_trans = self.x[index]
        Y_trans = self.y[index]

        if self.transform is not None:
            # Same seed before both calls -> identical random augmentation.
            seed = torch.randint(2147483647, (1, 1)).item()
            torch.manual_seed(seed)
            X_trans = self.transform(X_trans)
            torch.manual_seed(seed)
            Y_trans = self.transform(Y_trans)

        # No re-seed here: re-using the seed would make the switch decision
        # replay the transform's first random draw and correlate the two.
        if self.switch_pixels is not None:
            X_trans, Y_trans = self._switch_random_pixels(X_trans, Y_trans)

        return X_trans, Y_trans

    def _switch_random_pixels(self, X_trans, Y_trans):
        """Swap a random set of pixel positions between two (C, H, W) images."""
        # n_max: exclusive upper bound on the number of switched pixels.
        # p: probability that any pixels are switched at all.
        n_max, p = self.switch_pixels
        if torch.rand(1).item() >= p:
            return X_trans, Y_trans

        # n is drawn from [50, n_max); force n_max > 50 so the range is non-empty.
        n_max = max(int(n_max), 51)
        n = int(torch.randint(low=50, high=n_max, size=(1,)).item())
        rows = torch.randint(low=0, high=X_trans.shape[-2], size=(n,))
        cols = torch.randint(low=0, high=X_trans.shape[-1], size=(n,))

        # Write into clones: the inputs may be views of self.x / self.y, and
        # reading from the untouched originals makes the swap order-independent.
        X_out = X_trans.clone()
        Y_out = Y_trans.clone()
        X_out[:, rows, cols] = Y_trans[:, rows, cols]
        Y_out[:, rows, cols] = X_trans[:, rows, cols]
        return X_out, Y_out


class MyRotateTransform(torch.nn.Module):
    """Randomly rotate an image by one angle picked from a fixed list.

    With probability ``p`` an angle is drawn uniformly from ``angles`` and the
    image is rotated by it with ``TF.rotate``.  Otherwise, or when ``angles``
    is empty, the input is returned unchanged.

    Args:
        angles: candidate rotation angles in degrees.
        p: probability of applying a rotation.
    """

    def __init__(self, angles, p=0.8):
        # nn.Module.__init__ must run: it creates the internal registries
        # without which repr(), .eval(), .to(), ... raise AttributeError.
        super().__init__()
        self.angles = angles
        self.p = p

    def forward(self, x):
        # Draw the decision first so RNG consumption is the same on every call.
        if torch.rand(1).item() >= self.p or len(self.angles) == 0:
            # Angle 0 is the identity; skip the resampling entirely.
            return x
        rand_index = int(torch.randint(low=0, high=len(self.angles), size=(1,)).item())
        return TF.rotate(x, self.angles[rand_index])


# Geometric augmentation.  With probability 0.8 the whole list below is
# applied to a sample; each member then fires on its own with p = 0.8.
transform2 = transforms.RandomApply(
    torch.nn.ModuleList(
        [
            transforms.RandomHorizontalFlip(p=0.8),  # mirror left/right
            transforms.RandomVerticalFlip(p=0.8),  # mirror top/bottom
            MyRotateTransform(angles=[90, 180, 270], p=0.8),  # quarter turns
        ]
    ),
    p=0.8,
)


SIZE = 50000
BATCH_SIZE = 128
# Pixel-switching parameters (n_max, p).
# NOTE(review): not yet passed to any Dataset(switch_pixels=...), so unused.
n_max = 250  # about a fourth of the total number of pixels in an image.
p = 0.5
train_set_aug = Dataset(SIZE, transform=transform2)
train_set = Dataset(SIZE)


# Model initialization. Both models live on `device` and start from identical
# weights, so the only difference between the two runs is the augmentation.
model = AE_small5().to(device)
model_aug = AE_small5().to(device)
model_aug.load_state_dict(model.state_dict())

# Validation using MSE loss function
loss_function = nn.MSELoss()

# Adam optimizers with lr = 0.001
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, betas=(0.9, 0.99))
optimizer_aug = torch.optim.Adam(model_aug.parameters(), lr=1e-3, betas=(0.9, 0.99))

# DataLoaders for training
loader_1 = torch.utils.data.DataLoader(dataset=train_set, batch_size=BATCH_SIZE, shuffle=True)
loader_aug = torch.utils.data.DataLoader(dataset=train_set_aug, batch_size=BATCH_SIZE, shuffle=True)


def _train(net, opt, loader, loss_fn, n_epochs):
    """Train ``net`` on ``loader``; return (per-batch losses, per-epoch snapshots).

    Each snapshot is ``(epoch, target, reconstruction)`` for the last batch of
    that epoch, detached and moved to the CPU so it does not pin the autograd
    graph or GPU memory.
    """
    batch_losses = []
    snapshots = []
    for epoch in range(n_epochs):
        print("epoch : ", epoch)
        for noisy_imgs_1, noisy_imgs_2 in loader:
            noisy_imgs_1 = noisy_imgs_1.to(device)
            noisy_imgs_2 = noisy_imgs_2.to(device)
            reconstructed = net(noisy_imgs_1)
            loss = loss_fn(reconstructed, noisy_imgs_2)

            # Reset gradients, backpropagate, update parameters.
            opt.zero_grad()
            loss.backward()
            opt.step()
            # .item() works for CPU and CUDA tensors (.numpy() fails on CUDA).
            batch_losses.append(loss.item())
        snapshots.append((epoch, noisy_imgs_2.detach().cpu(), reconstructed.detach().cpu()))
    return batch_losses, snapshots


# OPTIMIZATION
epochs = 10
print("Training with augmentation : ")
losses_aug, outputs_aug = _train(model_aug, optimizer_aug, loader_aug, loss_function, epochs)

print("Training without augmentation : ")
losses, outputs = _train(model, optimizer, loader_1, loss_function, epochs)

# Create the output directory up front; savefig/torch.save do not create it.
os.makedirs("./Data_aug", exist_ok=True)

# Save the weights before plotting, so a plotting failure cannot lose them.
PATH = "./Data_aug/AE_small5_model.pth"
torch.save(model.state_dict(), PATH)

PATH = "./Data_aug/AE_small5_model_aug.pth"
torch.save(model_aug.state_dict(), PATH)

# Plot style
plt.style.use('fivethirtyeight')
plt.xlabel('Iterations')
plt.ylabel('Loss')

# Plot the last 100 per-batch losses of each run
plt.plot(losses[-100:])
plt.plot(losses_aug[-100:])
plt.legend(["Without augmentation", "With augmentation"])
plt.savefig("./Data_aug/AE_small5_losses")
plt.show()



cpu
Training data : 
 noisy_imgs_1 :  torch.Size([50000, 3, 32, 32]) 
 noisy_imgs_2 :  torch.Size([50000, 3, 32, 32])
Data reduced : 
 noisy_imgs_1_reduced :  torch.Size([50000, 3, 32, 32]) 
 noisy_imgs_2_reduced :  torch.Size([50000, 3, 32, 32])
Type :  torch.uint8
With data augmentation : transform.
Training data : 
 noisy_imgs_1 :  torch.Size([50000, 3, 32, 32]) 
 noisy_imgs_2 :  torch.Size([50000, 3, 32, 32])
Data reduced : 
 noisy_imgs_1_reduced :  torch.Size([50000, 3, 32, 32]) 
 noisy_imgs_2_reduced :  torch.Size([50000, 3, 32, 32])
Type :  torch.uint8
Training with augmentation : 
epoch :  0


KeyboardInterrupt: 