<a href="https://colab.research.google.com/github/migdashn/Deep-Learning-Projects/blob/main/Colorozation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Imports

In [None]:
import pandas
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data
from sklearn.utils import shuffle
from torchvision import datasets, transforms
from glob import glob

The CIFAR-10 dataset consists of 60000 32x32 colour images in 10 classes, with 6000 images per class. There are 50000 training images and 10000 test images.

The dataset is divided into five training batches and one test batch, each with 10000 images. The test batch contains exactly 1000 randomly-selected images from each class. The training batches contain the remaining images in random order, but some training batches may contain more images from one class than another. Between them, the training batches contain exactly 5000 images from each class.

Link: [CIFAR-10](https://www.cs.toronto.edu/~kriz/cifar.html)

In [None]:

cifar_data = datasets.CIFAR10('data', train=True, download=True, transform=transforms.ToTensor())

train_loader = torch.utils.data.DataLoader(cifar_data, batch_size=64, shuffle=True)



STL-10 dataset:

* 10 classes: airplane, bird, car, cat, deer, dog, horse, monkey, ship, truck.
* Images are 96x96 pixels, color.
* 500 training images (10 pre-defined folds) and 800 test images per class.
* 100000 unlabeled images for unsupervised learning. These examples are extracted from a similar but broader distribution of images. For instance, it contains other types of animals (bears, rabbits, etc.) and vehicles (trains, buses, etc.) in addition to the ones in the labeled set.

Link: [STL-10](https://cs.stanford.edu/~acoates/stl10/)

In [None]:
# TODO: check the data to see how to use it

To convert a color image tensor to black and white, collapse its three channels (red, green, blue) into one. The simplest way is the unweighted mean of the three channels, which gives a single-channel image with the same height and width. Here is an example in PyTorch:

In [None]:
# Create a color image Tensor
color_tensor = torch.rand(3, 256, 256)

# Average the R, G and B channels; keepdim=True keeps a 1x256x256 shape
gray_tensor = torch.mean(color_tensor, dim=0, keepdim=True)

Alternatively, use a weighted average that better matches perceived brightness. The ITU-R BT.601 luma formula weights green most and blue least: $Y = 0.299R + 0.587G + 0.114B$

In [None]:
gray_tensor = (0.299 * color_tensor[0] + 0.587 * color_tensor[1] + 0.114 * color_tensor[2]).unsqueeze(0)  # keep a 1x256x256 shape, like the mean version

In [None]:
# TODO: create a black-and-white dataset using one of the methods above
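A minimal sketch for that TODO, assuming the weighted (luma) conversion and a base dataset that yields `(image, label)` pairs of 3xHxW tensors in [0, 1], as `datasets.CIFAR10(..., transform=transforms.ToTensor())` does. The wrapper class `GrayColorPairs` is hypothetical: it drops the class label and returns a (grayscale input, color target) pair, both rescaled to [-1, 1] so the target matches a tanh generator output.

```python
import torch
from torch.utils.data import DataLoader, Dataset, TensorDataset


class GrayColorPairs(Dataset):
    """Hypothetical wrapper: (image, label) dataset -> (grayscale, color) pairs in [-1, 1]."""
    WEIGHTS = torch.tensor([0.299, 0.587, 0.114]).view(3, 1, 1)  # BT.601 luma weights

    def __init__(self, base):
        self.base = base

    def __len__(self):
        return len(self.base)

    def __getitem__(self, idx):
        color, _ = self.base[idx]                               # class label is not needed
        gray = (color * self.WEIGHTS).sum(dim=0, keepdim=True)  # 1xHxW grayscale
        return gray * 2 - 1, color * 2 - 1                      # [0, 1] -> [-1, 1]


# Random stand-in for the CIFAR-10 dataset (3x32x32 images plus labels)
base = TensorDataset(torch.rand(10, 3, 32, 32), torch.zeros(10, dtype=torch.long))
pairs = GrayColorPairs(base)
loader = DataLoader(pairs, batch_size=4, shuffle=True)
gray_batch, color_batch = next(iter(loader))
print(gray_batch.shape, color_batch.shape)
```

In the notebook, the stand-in `base` would be `cifar_data` from the loading cell.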

Generator

## **Reference model: [Kaggle source code](https://www.kaggle.com/code/utkarshsaxenadn/landscape-colorizer-pix2pixgan/notebook#Discriminator)**

In [None]:
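class Encoder(nn.Module):
    """Downsampling block: 4x4 conv with stride 2 (halves H and W), optional BatchNorm, LeakyReLU(0.2)."""
    def __init__(self, in_channels, out_channels, use_norm=True):
        super(Encoder, self).__init__()
        # padding=1 makes the output exactly H/2 x W/2 for even H, W; BatchNorm makes a conv bias redundant
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1, bias=not use_norm)
        self.norm = nn.BatchNorm2d(out_channels) if use_norm else nn.Identity()
        self.lr = nn.LeakyReLU(0.2)

    def forward(self, inputs):
        x = self.conv(inputs)
        x = self.norm(x)
        x = self.lr(x)
        return x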




In [None]:
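class Decoder(nn.Module):
    """Upsampling block: 4x4 transposed conv with stride 2 (doubles H and W), BatchNorm, optional dropout, ReLU."""
    def __init__(self, in_channels, out_channels, use_dropout=False):
        super(Decoder, self).__init__()
        # padding=1 makes the output exactly 2H x 2W
        self.conv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1, bias=False)
        self.norm = nn.BatchNorm2d(out_channels)
        self.drop = nn.Dropout2d(0.5) if use_dropout else nn.Identity()
        self.rel = nn.ReLU()

    def forward(self, inputs):
        x = self.conv(inputs)
        x = self.norm(x)
        x = self.drop(x)
        x = self.rel(x)
        return x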
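Whether these blocks halve and double the spatial size exactly depends on the padding. For `nn.Conv2d` (dilation 1) the output height is $\lfloor (H + 2p - k)/s \rfloor + 1$, and for `nn.ConvTranspose2d` (dilation 1, no output padding) it is $(H - 1)s - 2p + k$. With $k = 4$, $s = 2$, $p = 1$ an even $H$ is halved and then doubled exactly, so encoder and decoder feature maps line up for U-Net skip connections. With $p = 0$ a 32-pixel input shrinks to 15 instead. The helper names below are hypothetical; the sketch checks both formulas against real PyTorch layers.

```python
import torch
import torch.nn as nn


def conv_out(h, k=4, s=2, p=1):
    """Output size of nn.Conv2d along one spatial dimension (dilation 1)."""
    return (h + 2 * p - k) // s + 1


def conv_transpose_out(h, k=4, s=2, p=1):
    """Output size of nn.ConvTranspose2d (dilation 1, output_padding 0)."""
    return (h - 1) * s - 2 * p + k


x = torch.rand(1, 8, 32, 32)
down = nn.Conv2d(8, 16, kernel_size=4, stride=2, padding=1)
up = nn.ConvTranspose2d(16, 8, kernel_size=4, stride=2, padding=1)

y = down(x)   # encoder-style block
z = up(y)     # decoder-style block
print(conv_out(32), tuple(y.shape))
print(conv_transpose_out(16), tuple(z.shape))
print(conv_out(32, p=0))  # without padding
```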

In [None]:
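class Generator(nn.Module):
    """Minimal U-Net sketch for 32x32 inputs: 1-channel grayscale in, 3-channel color out in [-1, 1].
    Depth and channel sizes are assumptions, not taken from the Kaggle notebook."""
    def __init__(self, in_channels=1, out_channels=3, base=64):
        super(Generator, self).__init__()

        def down(c_in, c_out, norm=True):      # encoder block (halves H and W)
            layers = [nn.Conv2d(c_in, c_out, kernel_size=4, stride=2, padding=1, bias=not norm)]
            if norm:
                layers.append(nn.BatchNorm2d(c_out))
            return nn.Sequential(*layers, nn.LeakyReLU(0.2))

        def up(c_in, c_out, dropout=False):    # decoder block (doubles H and W)
            layers = [nn.ConvTranspose2d(c_in, c_out, kernel_size=4, stride=2, padding=1, bias=False),
                      nn.BatchNorm2d(c_out)]
            if dropout:
                layers.append(nn.Dropout2d(0.5))
            return nn.Sequential(*layers, nn.ReLU())

        self.enc0 = down(in_channels, base, norm=False)   # 32 -> 16, encoder without batchNorm
        self.enc1 = down(base, base * 2)                  # 16 -> 8, encoder with batchNorm
        self.enc2 = down(base * 2, base * 4)              # 8 -> 4
        self.dec0 = up(base * 4, base * 2, dropout=True)  # 4 -> 8, decoder with dropout
        self.dec1 = up(base * 4, base)                    # 8 -> 16, decoder without dropout (input includes skip)
        self.out = nn.ConvTranspose2d(base * 2, out_channels, kernel_size=4, stride=2, padding=1)  # 16 -> 32

    def forward(self, inputs):
        e0 = self.enc0(inputs)
        e1 = self.enc1(e0)
        e2 = self.enc2(e1)
        d0 = torch.cat([self.dec0(e2), e1], dim=1)   # U-Net skip connection
        d1 = torch.cat([self.dec1(d0), e0], dim=1)
        return torch.tanh(self.out(d1))              # tanh matches the [-1, 1] rescaling in train()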



Discriminator

In [None]:
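class Discriminator(nn.Module):
    """PatchGAN-style sketch: scores (grayscale, color) pairs with one logit per image patch.
    Channel sizes (64/128/256/512) are assumptions borrowed from pix2pix."""
    def __init__(self, in_channels=1 + 3, base=64):
        super(Discriminator, self).__init__()
        self.enc0 = nn.Sequential(nn.Conv2d(in_channels, base, kernel_size=4, stride=2, padding=1),                # encoder without batchNorm
                                  nn.LeakyReLU(0.2))
        self.enc1 = nn.Sequential(nn.Conv2d(base, base * 2, kernel_size=4, stride=2, padding=1, bias=False),      # encoder with batchNorm
                                  nn.BatchNorm2d(base * 2),
                                  nn.LeakyReLU(0.2))
        self.enc2 = nn.Sequential(nn.Conv2d(base * 2, base * 4, kernel_size=4, stride=2, padding=1, bias=False),  # encoder with batchNorm
                                  nn.BatchNorm2d(base * 4),
                                  nn.LeakyReLU(0.2))
        self.zeropad = nn.ZeroPad2d(1)
        self.conv1 = nn.Conv2d(base * 4, base * 8, kernel_size=4, stride=1, bias=False)
        self.batchnorm = nn.BatchNorm2d(base * 8)
        self.lr = nn.LeakyReLU(0.2)
        self.conv2 = nn.Conv2d(base * 8, 1, kernel_size=4, stride=1)   # one logit per patch

    def forward(self, gray, color):
        # condition on the grayscale input: concatenate it with the real or generated color image
        x = torch.cat([gray, color], dim=1)
        x = self.enc0(x)
        x = self.enc1(x)
        x = self.enc2(x)
        x = self.zeropad(x)
        x = self.conv1(x)
        x = self.batchnorm(x)
        x = self.lr(x)
        x = self.zeropad(x)
        x = self.conv2(x)
        return x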
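This discriminator follows the pix2pix PatchGAN layout: three stride-2 encoder calls in `forward`, then two stride-1 convolutions. Each output value therefore scores one patch of the input rather than the whole image. The short calculation below, assuming 4x4 kernels throughout as in the cell above, gives the patch size (receptive field) each output sees. Zero padding shifts positions but does not change the receptive field. The function name `receptive_field` is hypothetical.

```python
def receptive_field(layers):
    """Receptive field (in input pixels) of a stack of (kernel, stride) conv layers."""
    rf, jump = 1, 1  # field of one output pixel; input distance between adjacent outputs
    for kernel, stride in layers:
        rf += (kernel - 1) * jump
        jump *= stride
    return rf


# three stride-2 encoder blocks, then conv1 and conv2 with stride 1
patchgan = [(4, 2), (4, 2), (4, 2), (4, 1), (4, 1)]
print(receptive_field(patchgan))  # 70
```

70x70 is the patch size of the pix2pix PatchGAN. On 32x32 CIFAR-10 images each patch already covers the whole image; dropping one stride-2 block gives a 34-pixel field instead, which is a design choice worth testing rather than a given.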


Train model

### **Taken from Nir's lecture; needs to be modified**
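The training loop below calls `real_loss` and `fake_loss`, which this notebook never defines. A minimal sketch, assuming binary cross-entropy on raw discriminator logits (the usual choice in DCGAN lecture code) and one-sided label smoothing to 0.9 when `smooth=True`. Using `full_like`/`zeros_like` makes the targets match any output shape, so the same functions work for a single score per image or for a PatchGAN map of per-patch logits.

```python
import torch
import torch.nn as nn

criterion = nn.BCEWithLogitsLoss()  # expects raw logits, so D must not end in a sigmoid


def real_loss(D_out, smooth=False):
    """Push D's logits toward 'real'; smooth=True uses 0.9 targets (assumed smoothing value)."""
    target = 0.9 if smooth else 1.0
    return criterion(D_out, torch.full_like(D_out, target))


def fake_loss(D_out):
    """Push D's logits toward 'fake' (target 0)."""
    return criterion(D_out, torch.zeros_like(D_out))


logits = torch.zeros(4, 1, 6, 6)  # e.g. a batch of 6x6 patch logits
print(real_loss(logits).item(), fake_loss(logits).item())
```

At zero logits both losses equal ln 2, since the discriminator is undecided.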

In [None]:

def train(G, D, train_loader, z_size=100, lr=0.001, num_epochs=5): 
    d_optimizer = torch.optim.Adam(D.parameters(), lr=lr)
    g_optimizer = torch.optim.Adam(G.parameters(), lr=lr)


    # keep track of loss and generated, "fake" samples
    samples = []
    losses = []

    print_every = 400

    # Get some fixed data for sampling. These are images that are held
    # constant throughout training, and allow us to inspect the model's performance
    sample_size=16
    fixed_z = np.random.uniform(-1, 1, size=(sample_size, z_size))
    fixed_z = torch.from_numpy(fixed_z).float()

    # train the network
    D.train()
    G.train()
    for epoch in range(num_epochs):
        
        for batch_i, (real_images, _) in enumerate(train_loader):
                    
            batch_size = real_images.size(0)
            
            ## Important rescaling step - since the generator output layer is tanh ## 
            real_images = real_images*2 - 1  # rescale input images from [0, 1] to [-1, 1]
            
            # ============================================
            #            TRAIN THE DISCRIMINATOR
            # ============================================
            
            d_optimizer.zero_grad()
            
            # 1. Train with real images

            # Compute the discriminator losses on real images 
            # smooth the real labels
            D_real = D(real_images)
            d_real_loss = real_loss(D_real, smooth=True)
            
            # 2. Train with fake images
            
            # Generate fake images
            z = np.random.uniform(-1, 1, size=(batch_size, z_size))
            z = torch.from_numpy(z).float()
            fake_images = G(z)
            
            # Compute the discriminator losses on fake images        
            D_fake = D(fake_images)
            d_fake_loss = fake_loss(D_fake)
            
            # add up loss and perform backprop
            d_loss = d_real_loss + d_fake_loss
            d_loss.backward()
            d_optimizer.step()
            
            
            # =========================================
            #            TRAIN THE GENERATOR
            # =========================================
            g_optimizer.zero_grad()
            
            # 1. Train with fake images and flipped labels
            
            # Generate fake images
            z = np.random.uniform(-1, 1, size=(batch_size, z_size))
            z = torch.from_numpy(z).float()
            fake_images = G(z)
            
            # Compute the discriminator losses on fake images using flipped labels
            D_fake = D(fake_images)
            g_loss = real_loss(D_fake) # use real loss to flip labels
            
            # perform backprop
            g_loss.backward()
            g_optimizer.step()

            # Print some loss stats
            if batch_i % print_every == 0:
                # print discriminator and generator loss
                print('Epoch [{:5d}/{:5d}] | d_loss: {:6.4f} | g_loss: {:6.4f}'.format(
                        epoch+1, num_epochs, d_loss.item(), g_loss.item()))

        
        ## AFTER EACH EPOCH##
        # append discriminator loss and generator loss
        losses.append((d_loss.item(), g_loss.item()))
        
        # generate and save sample, fake images
        G.eval() # eval mode for generating samples
        with torch.no_grad():  # no autograd graph needed for preview samples
            samples_z = G(fixed_z)
        samples.append(samples_z)
        G.train() # back to train mode

    # plot learning curve
    fig, ax = plt.subplots()
    losses = np.array(losses)
    plt.plot(losses.T[0], label='Discriminator')
    plt.plot(losses.T[1], label='Generator')
    plt.title("Training Losses")
    plt.legend()     
     
    return samples
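The loop above is an unconditional DCGAN loop: it feeds random noise `z` to `G` and scores single images with `D`. One common adaptation for colorization, the pix2pix recipe that the Kaggle notebook follows, feeds the grayscale image to the generator instead of noise, lets the discriminator score (grayscale, color) pairs, and adds an L1 term that keeps generated colors close to the ground truth (pix2pix weights it with λ = 100). The sketch below shows one such training step. It is an illustration, not this notebook's method: the function name `pix2pix_step`, the two-argument `D(gray, color)` signature, and the tiny stand-in networks are assumptions; the Adam settings (lr 2e-4, betas 0.5 and 0.999) are the ones reported for pix2pix.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def pix2pix_step(G, D, gray, color, g_opt, d_opt, l1_weight=100.0):
    """One conditional-GAN update. gray: (N,1,H,W), color: (N,3,H,W), both in [-1, 1].
    D(gray, color) is assumed to return raw logits of any shape."""
    bce = nn.BCEWithLogitsLoss()

    # Discriminator: real pairs -> 1, generated pairs -> 0
    d_opt.zero_grad()
    fake = G(gray)
    d_real = D(gray, color)
    d_fake = D(gray, fake.detach())  # detach: no generator gradients in this step
    d_loss = bce(d_real, torch.ones_like(d_real)) + bce(d_fake, torch.zeros_like(d_fake))
    d_loss.backward()
    d_opt.step()

    # Generator: fool D and stay close to the true colors (L1 term)
    g_opt.zero_grad()
    d_fake = D(gray, fake)
    g_loss = bce(d_fake, torch.ones_like(d_fake)) + l1_weight * F.l1_loss(fake, color)
    g_loss.backward()
    g_opt.step()
    return d_loss.item(), g_loss.item()


class TinyD(nn.Module):
    """Hypothetical stand-in discriminator, only to show the call pattern."""
    def __init__(self):
        super().__init__()
        self.net = nn.Conv2d(1 + 3, 1, kernel_size=4, stride=2, padding=1)

    def forward(self, gray, color):
        return self.net(torch.cat([gray, color], dim=1))  # per-patch logits


torch.manual_seed(0)
G = nn.Sequential(nn.Conv2d(1, 3, kernel_size=3, padding=1), nn.Tanh())  # stand-in generator
D = TinyD()
g_opt = torch.optim.Adam(G.parameters(), lr=2e-4, betas=(0.5, 0.999))
d_opt = torch.optim.Adam(D.parameters(), lr=2e-4, betas=(0.5, 0.999))

gray = torch.rand(4, 1, 32, 32) * 2 - 1
color = torch.rand(4, 3, 32, 32) * 2 - 1
print(pix2pix_step(G, D, gray, color, g_opt, d_opt))
```

The generated batch is computed once and reused: detached for the discriminator update, attached for the generator update, so `G` runs only one forward pass per step.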