In [None]:

import torch
import torchvision
import numpy as np
from tqdm.auto import tqdm
from matplotlib import pyplot as plt
import warnings
# UserWarning is a Python builtin, so it needs no import.

# Silence every UserWarning for the rest of the session (this filter is
# process-wide, so it also hides warnings raised by third-party libraries).
warnings.simplefilter("ignore", UserWarning)

# Prefer the GPU when one is visible, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Display the accelerator in use. CUDA is only queried when it is available,
# because torch.cuda.get_device_name raises on CPU-only machines.
torch.cuda.get_device_name(0) if device == "cuda" else "cpu"


'Tesla V100-SXM2-16GB'

In [None]:

# CIFAR-10 train / test loaders (downloaded into ./data on first use).
# Images are converted with ToTensor; only the training split is shuffled,
# which is why `shuffle` simply mirrors the `train` flag below.
training_set, validation_set = (
    torch.utils.data.DataLoader(
        torchvision.datasets.CIFAR10(
            'data',
            train=is_train_split,
            download=True,
            transform=torchvision.transforms.ToTensor(),
        ),
        batch_size=32,
        shuffle=is_train_split,
    )
    for is_train_split in (True, False)
)



Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to data/cifar-10-python.tar.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting data/cifar-10-python.tar.gz to data
Files already downloaded and verified


In [None]:


class MaskedConv2d(torch.nn.Conv2d):
    """Conv2d whose kernel is masked to be causal over pixels and colour groups.

    Channels are treated as interleaved colour groups: index ``0, 3, 6, ...`` is
    "red", ``1, 4, 7, ...`` is "green" and ``2, 5, 8, ...`` is "blue", for both
    input and output channels.

    Spatially, every tap in the rows above the kernel centre and every tap to
    the left of the centre in the centre row is kept; the centre tap, the taps
    to its right and all rows below are zeroed. The centre tap is then
    re-enabled per colour group:

    * type ``"A"``: green outputs see red inputs; blue outputs see red and
      green inputs; red outputs see nothing at the centre.
    * type ``"B"``: red outputs see red inputs; green outputs see red and green
      inputs; blue outputs see every input channel.

    Args:
        mask_type: ``"A"`` or ``"B"``.
        *arg, **kwargs: forwarded unchanged to ``torch.nn.Conv2d``.

    Raises:
        AssertionError: if ``mask_type`` is not ``"A"`` or ``"B"``.

    The mask is stored as a non-persistent buffer named ``mask``: it moves with
    ``module.to(...)`` / ``module.cuda()`` and is not written to ``state_dict``.
    """

    def __init__(self, mask_type, *arg, **kwargs):
        super().__init__(*arg, **kwargs)
        assert mask_type in ["A", "B"], "Invalid mask type"
        # Legacy attribute names kept for compatibility. The values are simply
        # the four dimensions of the weight tensor, in order.
        self.batch_size, self.channels, self.height, self.width = self.weight.data.shape
        self.mask_type = mask_type
        center_row, center_col = self.height // 2, self.width // 2

        # Interleaved colour groups over the input and output channels.
        in_red = torch.arange(0, self.in_channels, 3)
        in_green = torch.arange(1, self.in_channels, 3)
        in_red_green = torch.cat((in_red, in_green))
        out_red = torch.arange(0, self.out_channels, 3)
        out_green = torch.arange(1, self.out_channels, 3)
        out_blue = torch.arange(2, self.out_channels, 3)

        # Spatial part: hide the centre tap and everything after it in raster order.
        mask = torch.ones_like(self.weight.data)
        mask[:, :, center_row:, center_col:] = 0
        mask[:, :, center_row + 1:, :] = 0

        # View of the centre tap, shaped like the first two weight dimensions;
        # writes through to `mask`.
        center = mask[:, :, center_row, center_col]
        if mask_type == "A":
            center[out_green[:, None], in_red[None, :]] = 1
            center[out_blue[:, None], in_red_green[None, :]] = 1
        else:  # mask_type == "B"
            center[out_red[:, None], in_red[None, :]] = 1
            center[out_green[:, None], in_red_green[None, :]] = 1
            center[out_blue] = 1

        # A buffer follows the module across devices and dtypes, so the class no
        # longer depends on a notebook-global `device`. persistent=False keeps
        # the state_dict keys identical to those of a plain Conv2d.
        self.register_buffer("mask", mask, persistent=False)

    def forward(self, images):
        """Zero the masked weights in place, then run the ordinary convolution."""
        # `.data` bypasses autograd. Masking on every call also discards any
        # optimizer update that was applied to a hidden tap.
        self.weight.data.mul_(self.mask)
        return super().forward(images)




In [None]:


class PixelCNN(torch.nn.Module):
    """Stack of masked convolutions that predicts per-pixel, per-colour logits.

    Layout (attribute names ``conv1..conv15`` / ``batch_norm1..batch_norm15``
    are preserved so existing checkpoints keep loading):

    * ``conv1``: type-A 7x7 masked conv, ``in_channels -> hidden_channel``.
    * ``conv2..conv13``: type-B 7x7 masked convs, ``hidden_channel -> hidden_channel``.
    * ``conv14``, ``conv15``: type-B 1x1 masked convs widening to ``3 * 1024`` channels.
    * ``conv_red`` / ``conv_green`` / ``conv_blue``: plain 1x1 heads, each fed the
      1024 channels of its interleaved colour group (stride-3 channel slices).

    Every masked conv is followed by BatchNorm2d and ReLU.

    Args:
        in_channels: channels of the input image.
        hidden_channel: width of the 7x7 masked layers.
        out_channels: number of logits per colour per pixel.
    """

    # Number of 7x7 masked layers (the first is type A, the rest type B).
    NUM_SPATIAL_LAYERS = 13
    # Feature channels handed to each colour head.
    HEAD_CHANNELS = 1024

    def __init__(self, in_channels, hidden_channel, out_channels):
        super(PixelCNN, self).__init__()
        self.hidden_channel = hidden_channel
        head_width = self.HEAD_CHANNELS * 3

        # One entry per conv/BN block: (mask type, in, out, kernel size, padding).
        layer_specs = [("A", in_channels, hidden_channel, 7, 3)]
        layer_specs += [("B", hidden_channel, hidden_channel, 7, 3)] * (self.NUM_SPATIAL_LAYERS - 1)
        layer_specs += [
            ("B", hidden_channel, head_width, 1, 0),
            ("B", head_width, head_width, 1, 0),
        ]
        self.num_blocks = len(layer_specs)

        # Registered as conv{i} followed by batch_norm{i}, in index order, so the
        # parameter order and state_dict keys match the hand-written version.
        for idx, (mask_type, c_in, c_out, kernel, pad) in enumerate(layer_specs, start=1):
            setattr(self, f"conv{idx}", MaskedConv2d(
                mask_type=mask_type, in_channels=c_in, out_channels=c_out,
                kernel_size=kernel, stride=1, padding=pad))
            setattr(self, f"batch_norm{idx}", torch.nn.BatchNorm2d(c_out))

        self.conv_red = torch.nn.Conv2d(in_channels=self.HEAD_CHANNELS, out_channels=out_channels, kernel_size=1)
        self.conv_green = torch.nn.Conv2d(in_channels=self.HEAD_CHANNELS, out_channels=out_channels, kernel_size=1)
        self.conv_blue = torch.nn.Conv2d(in_channels=self.HEAD_CHANNELS, out_channels=out_channels, kernel_size=1)

        # Kept for backward compatibility; forward() uses the equivalent strided
        # slices (0::3, 1::3, 2::3) so no index tensor is moved between devices.
        self.red_channel_index = torch.arange(0, head_width, 3).long()
        self.green_channel_index = torch.arange(1, head_width, 3).long()
        self.blue_channel_index = torch.arange(2, head_width, 3).long()

    def forward(self, images):
        """Return logits stacked as ``(batch, out_channels, 3, H, W)``.

        Dimension 2 indexes the colour (red, green, blue), as produced by the
        final ``torch.stack(..., dim=2)``.
        """
        pred = images
        for idx in range(1, self.num_blocks + 1):
            pred = getattr(self, f"conv{idx}")(pred)
            pred = getattr(self, f"batch_norm{idx}")(pred)
            pred = torch.relu(pred)

        # Split the interleaved features into their colour groups.
        red_pred = self.conv_red(pred[:, 0::3])
        green_pred = self.conv_green(pred[:, 1::3])
        blue_pred = self.conv_blue(pred[:, 2::3])

        return torch.stack((red_pred, green_pred, blue_pred), dim=2)

    def sample(self, shape, count, label=None, device='cuda'):
        """Draw ``count`` images one sub-pixel at a time, in raster order.

        Args:
            shape: ``(channels, height, width)`` of one image.
            count: number of images to draw.
            label: accepted for backward compatibility and ignored; the
                network takes no conditioning input.
            device: device on which the sample tensor is allocated. It must be
                the device the model lives on.

        Returns:
            Tensor of shape ``(count, *shape)``; each entry is a sampled level
            index divided by 255.
        """
        channels, height, width = shape
        samples = torch.zeros(count, *shape).to(device)

        # Sample with BatchNorm in inference mode. In training mode the layers
        # would normalise with statistics of the partially generated batch and
        # overwrite their running averages. The previous mode is restored after.
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                for i in tqdm(range(height)):
                    for j in range(width):
                        for c in range(channels):
                            unnormalized_probs = self(samples)
                            pixel_probs = torch.softmax(unnormalized_probs[:, :, c, i, j], dim=1)
                            # squeeze(1) keeps the batch dimension even when count == 1.
                            sampled_levels = torch.multinomial(pixel_probs, 1).squeeze(1).float() / 255
                            samples[:, c, i, j] = sampled_levels
        finally:
            self.train(was_training)

        return samples


# Build the network on the selected device (Module.to returns the module itself)
# and attach an Adam optimizer with its default hyper-parameters.
pixel_cnn = PixelCNN(in_channels=3, hidden_channel=128*3, out_channels=256).to(device=device)
optim = torch.optim.Adam(pixel_cnn.parameters())
# Loss histories for the run.
train_loss, validation_loss = [], []



In [None]:

def save_model(epoch, model, optim, train_loss, save_path):
    """Serialise a training checkpoint to ``save_path`` with ``torch.save``.

    The checkpoint is a dict with the keys ``epoch``, ``model_state_dict``,
    ``optim_state_dict`` and ``train_loss``.

    Args:
        epoch: value stored under ``"epoch"``.
        model: object whose ``state_dict()`` is stored.
        optim: object whose ``state_dict()`` is stored.
        train_loss: value stored under ``"train_loss"``.
        save_path: destination handed to ``torch.save``.
    """
    check_point = dict(
        epoch=epoch,
        model_state_dict=model.state_dict(),
        optim_state_dict=optim.state_dict(),
        train_loss=train_loss,
    )
    torch.save(check_point, save_path)
    print("Save model succesfully")

    

In [None]:


def generate_images(nb_images, images_shape, model, device="cpu"):
    """Sample images from ``model`` sub-pixel by sub-pixel and plot them as a grid.

    Args:
        nb_images: number of images to generate.
        images_shape: ``(height, width, channels)`` of one image.
        model: network called as ``model(batch)``; its output is indexed as
            ``[:, :, c, i, j]`` to get the logits of channel ``c`` at row
            ``i``, column ``j``.
        device: device for the sample tensor, used only when ``model`` has no
            parameters. Otherwise the samples are allocated on the device of
            the model's parameters, which is the only placement that can work.

    Returns:
        NumPy array holding the image grid in ``(H, W, C)`` layout with values
        in ``[0, 1]``.
    """
    height, width, channels = images_shape
    max_level = 255.0  # sampled class indices are mapped onto [0, 1] with this divisor

    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else torch.device(device)
    generated_images = torch.zeros((nb_images, channels, height, width), device=target_device).float()

    # Sample with BatchNorm (and any other mode-dependent layer) in inference
    # mode, then restore whatever mode the caller had set.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for i in tqdm(range(height)):
                for j in range(width):
                    for c in range(channels):
                        pred = model(generated_images)
                        prob = torch.nn.functional.softmax(pred[:, :, c, i, j], dim=1)
                        levels = torch.multinomial(prob, 1).float().flatten()
                        # Feed values back on the same [0, 1] scale as the
                        # ToTensor-converted training images, not raw 0..255 indices.
                        generated_images[:, c, i, j] = levels / max_level
    finally:
        model.train(was_training)

    grid = torchvision.utils.make_grid(generated_images, nrow=max(1, int(np.sqrt(nb_images))))
    # Already in [0, 1], so no further rescaling is needed for imshow.
    grid = grid.permute(1, 2, 0).detach().cpu().numpy()
    plt.figure(figsize=(7.5, 7.5))
    plt.imshow(grid)
    plt.title("Generated image")
    plt.show()
    return grid

# generated_images = generate_images(16, (32,32,3), pixel_cnn, device=device)


In [None]:

# a = generated_images[10].permute(1,2,0).detach().cpu().numpy()
# plt.imshow(a/255)

In [None]:

epochs = 100
# The loss module holds no state, so build it once instead of once per batch.
criterion = torch.nn.CrossEntropyLoss()

for epoch in range(epochs):
    print(f"Epoch {epoch}\n")
    # Training 
    print("Start traning....")
    epoch_train_loss = []
    pixel_cnn.train()
    for images, _ in tqdm(training_set):
        images = images.to(device=device)
        pred = pixel_cnn(images)
        # Targets are the integer intensity levels. Round before casting so a
        # value such as 254.99999 (float error from the /255 then *255 round
        # trip) is not truncated to the wrong class.
        targets = (images * 255.0).round().long()
        loss = criterion(pred, targets)
        optim.zero_grad()
        loss.backward()
        optim.step()
        epoch_train_loss.append(loss.item())

    mean_epoch_loss = np.mean(epoch_train_loss)
    # Record the epoch average so the history (which save_model stores) is populated.
    train_loss.append(float(mean_epoch_loss))
    print(f"Finish epoch {epoch}. Train_loss = {mean_epoch_loss}")
    # if (epoch % 5 == 0):
    #     save_model(epoch, pixel_cnn, optim, train_loss, f"./drive/MyDrive/pred_cnn/pixel_cnn{epoch}.pt")

    # Sample with BatchNorm in inference mode; the next epoch switches back to
    # training mode via pixel_cnn.train() above.
    pixel_cnn.eval()
    generate_images(nb_images=36, images_shape=(32,32,3), model=pixel_cnn, device=device)
    print("------------------------------------------------------------------------------------")
    print("\n\n")


In [None]:

# check_point = torch.load("./drive/MyDrive/pred_cnn/pixel_cnn.pt")

# print(list(check_point.keys()))
# pixel_cnn.load_state_dict(check_point["model_state_dict"])
# optim.load_state_dict(check_point["optim_state_dict"])
# train_loss = check_point["train_loss"]
