<a href="https://colab.research.google.com/github/jhy9968/ECE6179_project/blob/main/2_blocks_all_sets.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
! pip install --quiet "matplotlib" "pytorch-lightning" "pandas" "torchmetrics"
import matplotlib.pyplot as plt
import pytorch_lightning as pl
from torch import Tensor
import torch.nn as nn
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks.progress import TQDMProgressBar
from pytorch_lightning.loggers import CSVLogger
from torch.nn import functional as F
from torchmetrics import MeanSquaredError
from pytorch_lightning.callbacks import Callback, ModelCheckpoint
import torchvision.models.resnet as RN

from torchvision.models import resnet18
from torchvision.models.resnet import ResNet18_Weights

from torch.utils.data import Dataset
from torch.utils.data import DataLoader

import os
import torch
from torchvision.datasets import STL10 as STL10
import torchvision.transforms as transforms
from torch.utils.data import random_split
!pip install torchmetrics
from torchmetrics import Accuracy
import torchvision
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime
# Colab-only: mount Google Drive so the shared project folder is reachable under /content/drive/
from google.colab import drive
drive.mount('/content/drive/')

####### CHANGE TO APPROPRIATE DIRECTORY TO STORE DATASET
# Project root on the mounted drive; the trailing slash matters because paths are built by concatenation
root_dir = "/content/drive/Shareddrives/ECE6179_project/"
dataset_dir = f"{root_dir}CNN-VAE/data/"
#For MonARCH
# dataset_dir = "/mnt/lustre/projects/ds19/SHARED"

# Side length in pixels of the (square) images: all images are 3x96x96
image_size = 96
# Example batch size
batch_size = 32

[K     |████████████████████████████████| 708 kB 9.4 MB/s 
[K     |████████████████████████████████| 529 kB 55.1 MB/s 
[K     |████████████████████████████████| 5.9 MB 52.2 MB/s 
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
tensorflow 2.8.2+zzzcolab20220929150707 requires tensorboard<2.9,>=2.8, but you have tensorboard 2.10.1 which is incompatible.[0m
[?25hLooking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Mounted at /content/drive/


In [2]:
def _tensor_and_normalise():
    """Return the tail shared by every pipeline: ToTensor, then per-channel Normalize(mean=0.5, std=0.5)."""
    return [
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]

#Labelled training images: random crop (after 4 px padding) and random mirroring for data augmentation
transform_train = transforms.Compose(
    [
        transforms.RandomCrop(image_size, padding=4),
        transforms.RandomHorizontalFlip(p=0.5),
    ]
    + _tensor_and_normalise()
)

#Unlabelled images: random mirroring only
transform_unlabelled = transforms.Compose(
    [transforms.RandomHorizontalFlip(p=0.5)] + _tensor_and_normalise()
)

#No random augmentation at test time: centre crop only
transform_test = transforms.Compose(
    [transforms.CenterCrop(image_size)] + _tensor_and_normalise()
)

In [3]:
#Load the labelled 'train' split twice over the same files:
#  - trainval_set applies the random training augmentation,
#  - trainval_eval_set applies the deterministic evaluation transform.
trainval_set = STL10(dataset_dir, split='train', transform=transform_train, download=True)
trainval_eval_set = STL10(dataset_dir, split='train', transform=transform_test, download=False)

#Use 10% of data for training - simulating low data scenario
num_train = int(len(trainval_set)*0.1)

#Split data into train/val sets
torch.manual_seed(0) #Set torch's random seed so that random split of data is reproducible
train_set, val_split = random_split(trainval_set, [num_train, len(trainval_set)-num_train])

#Validation uses the same held-out indices, but read through the un-augmented dataset.
#Taking the second half of random_split directly would apply RandomCrop/RandomHorizontalFlip
#to validation images and make the validation loss noisy.
val_set = torch.utils.data.Subset(trainval_eval_set, val_split.indices)

#Load test set
test_set = STL10(dataset_dir, split='test', transform=transform_test, download=True)

Files already downloaded and verified
Files already downloaded and verified


In [4]:
#Unlabelled split of STL10, read through the mirroring-only transform
unlabelled_set = STL10(
    dataset_dir,
    split='unlabeled',
    download=True,
    transform=transform_unlabelled,
)

Files already downloaded and verified


In [30]:
# Overrides the example batch size set in the first cell
batch_size = 128

# Only the two training sources are shuffled; validation and test keep dataset order
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
unlabelled_loader = DataLoader(unlabelled_set, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_set, batch_size=batch_size)
test_loader = DataLoader(test_set, batch_size=batch_size)

# len() of a loader is its number of batches
for _split_name, _loader in (
    ('train', train_loader),
    ('unlabelled', unlabelled_loader),
    ('valid', val_loader),
    ('test', test_loader),
):
    print(f'Size of {_split_name} loader:', len(_loader))

Size of train loader: 4
Size of unlabelled loader: 782
Size of valid loader: 36
Size of test loader: 63


In [31]:
# Pull a single batch from the unlabelled loader to sanity-check tensor shapes.
train_data_iter = iter(unlabelled_loader)
# Use the built-in next(): the iterator's Python-2 style `.next()` method is not
# available in newer PyTorch releases.
data, labels = next(train_data_iter)
print(data.shape)
print(labels.shape)

torch.Size([128, 3, 96, 96])
torch.Size([128])


In [32]:
class Refined_Unlabelled_Dataset(Dataset):
    """Dataset over a pre-saved stack of images, each paired with the dummy label -1.

    Args:
        img_dir: path handed to ``torch.load``; despite the name it is the saved
            file itself. The loaded object is indexed along its first axis.
        transform: optional callable applied to each image on access.
    """

    def __init__(self, img_dir, transform=None):
        # NOTE(review): torch.load unpickles the file; only point this at trusted data.
        self.transform = transform
        self.images = torch.load(img_dir)

    def __len__(self):
        # Number of images = size of the first axis
        n_images = self.images.shape[0]
        return n_images

    def __getitem__(self, idx):
        sample = self.images[idx]
        # A falsy transform (e.g. None) leaves the image untouched
        sample = self.transform(sample) if self.transform else sample
        # -1 is the placeholder label for unlabelled data
        return sample, -1

# Paths are re-stated here so this cell can be run without the first config cell
root_dir = "/content/drive/Shareddrives/ECE6179_project/"
dataset_dir = f"{root_dir}CNN-VAE/data/"

# Pre-saved refined images (no transform applied), served in shuffled batches
refined_unlabelled_set = Refined_Unlabelled_Dataset(img_dir=f"{dataset_dir}refined_images.pt")
refined_unlabelled_loader = DataLoader(
    refined_unlabelled_set,
    shuffle=True,
    batch_size=batch_size,
)

In [8]:
# Load the TensorBoard notebook extension
%load_ext tensorboard

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

cuda:0


# Model classes

In [33]:
class Encoder(nn.Module):
  """Convolutional encoder made from the first children of a backbone network.

  ``block1`` wraps children 0-4 of ``base_model`` and ``block2`` wraps the
  sub-modules of child 5 (for torchvision's ResNet-18 these should be the stem
  plus ``layer1``, and ``layer2`` - confirm against the torchvision source).

  Args:
    base_model: backbone whose children are reused. When ``None`` (default) a
      fresh ImageNet-pretrained ResNet-18 is created for this instance.
  """
  def __init__ (self, base_model=None):
    super(Encoder, self).__init__()
    # Build the backbone per instance. Calling resnet18(...) in the signature would
    # run once at class-definition time, and every Encoder() would then share (and
    # jointly train) the very same layer objects.
    if base_model is None:
      base_model = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1, progress=False)
    stages = list(base_model.children())
    self.block1 = nn.Sequential(*stages[:5])
    self.block2 = nn.Sequential(*stages[5])
    # self.block3 = nn.Sequential(*stages[6])
    # self.block4 = nn.Sequential(*stages[7])

  def forward(self, x):
    """Run ``x`` through block1 then block2 and return the feature map."""
    x = self.block1(x)
    x = self.block2(x)
    # x = self.block3(x)
    # x = self.block4(x)
    return x

  def print_model(self):
    """Print the module structure."""
    print(self)

  def _set_block_trainable(self, block, trainable):
    """Set ``requires_grad`` on every parameter of the 1-based ``block``.

    A ``block`` index that matches no child changes nothing.
    """
    for i, child in enumerate(self.children()):
      if i == block-1:
        for param in child.parameters():
          param.requires_grad = trainable

  def freeze_param(self, block):
    """Stop gradient updates for block number ``block`` (1-based)."""
    self._set_block_trainable(block, False)
    print('Freeze block '+str(block)+' parameters')

  def unfreeze_param(self, block):
    """Re-enable gradient updates for block number ``block`` (1-based)."""
    self._set_block_trainable(block, True)
    print('Unfreeze block '+str(block)+' parameters')

In [34]:
class Decoder(nn.Module):
  """Transposed-convolution decoder that upsamples a feature map to a 3-channel image.

  Three ``ConvTranspose2d`` (stride 2) / ``Conv2d`` (stride 1) pairs are applied in
  sequence, so height and width grow by a factor of 8 overall.

  Args:
    inplanes: number of channels of the incoming feature map (default 128).
    intMed_planes: number of channels used by every intermediate layer (default 64).

  Layer numbering starts at 5 because the ``convTrans1`` ... ``conv4`` stages of the
  deeper variants (512 -> 256 -> 128 channels) are omitted here; the attribute names
  are left unchanged so the ``state_dict`` keys stay the same.

  NOTE(review): there are no activation functions between the layers, so the decoder
  is a purely linear map - confirm this is intended.
  """
  def __init__ (self, inplanes = 128, intMed_planes = 64):
    super(Decoder, self).__init__()
    self.inplanes      = inplanes
    self.intMed_planes = intMed_planes

    # Channel counts come from the constructor arguments; with the defaults this
    # builds exactly 128 -> 64 -> 64 -> 64 -> 64 -> 64 -> 3.
    self.convTrans5 = nn.ConvTranspose2d(in_channels = inplanes, out_channels = intMed_planes, kernel_size = 3, stride = 2, padding = 1, output_padding=1)
    self.conv6 = nn.Conv2d(in_channels = intMed_planes, out_channels = intMed_planes, kernel_size = 3, stride = 1, padding = 1)
    self.convTrans7 = nn.ConvTranspose2d(in_channels = intMed_planes, out_channels = intMed_planes, kernel_size = 3, stride = 2, padding = 1, output_padding=1)
    self.conv8 = nn.Conv2d(in_channels = intMed_planes, out_channels = intMed_planes, kernel_size = 3, stride = 1, padding = 1)
    self.convTrans9 = nn.ConvTranspose2d(in_channels = intMed_planes, out_channels = intMed_planes, kernel_size = 3, stride = 2, padding = 1, output_padding=1)
    self.conv10 = nn.Conv2d(in_channels = intMed_planes, out_channels = 3, kernel_size = 3, stride = 1, padding = 1)

    # output_padding=1 makes each transposed convolution exactly double H and W.
    # Keep the extra row/column it adds in mind when building the loss function.

  def forward (self, x):
    """Upsample feature map ``x`` (``inplanes`` channels) to a 3-channel output."""
    x = self.convTrans5(x)
    x = self.conv6(x)
    x = self.convTrans7(x)
    x = self.conv8(x)
    x = self.convTrans9(x)
    x = self.conv10(x)

    return x



In [35]:
class AutoEncoder(LightningModule):
  """Encoder/decoder pair trained to reconstruct its input under an MSE loss.

  Args:
    learning_rate: learning rate handed to Adam.
    encoder: encoder module; a new ``Encoder()`` is built when ``None``.
    decoder: decoder module; a new ``Decoder()`` is built when ``None``.
    trainDataLoader: loader returned by ``train_dataloader``.
    valDataLoader: loader returned by ``val_dataloader``.
    testDataLoader: loader returned by ``test_dataloader``.
    log_dir: directory for the TensorBoard ``SummaryWriter``. When ``None`` the
      notebook-level ``full_dir + 'runs/'`` is used, as before.
  """
  def __init__ (self, learning_rate = 1e-4, encoder=None, decoder=None, trainDataLoader=None, valDataLoader=None, testDataLoader=None, log_dir=None):
    super().__init__()

    self.learning_rate = learning_rate
    self.loss_fun = nn.MSELoss()

    # Build sub-modules per instance. Default arguments such as `encoder=Encoder()`
    # are evaluated once at class-definition time and would be shared by every
    # AutoEncoder created without explicit modules.
    self.Encoder = encoder if encoder is not None else Encoder()
    self.Decoder = decoder if decoder is not None else Decoder()

    # Metric objects are registered but not updated by any step below.
    self.train_accuracy = MeanSquaredError()
    self.vald_accuracy = MeanSquaredError()
    self.test_accuracy = MeanSquaredError()

    self.trainDataLoader = trainDataLoader
    self.valDataLoader = valDataLoader
    self.testDataLoader = testDataLoader

    # Fall back to the notebook global only when no directory is given.
    if log_dir is None:
      log_dir = full_dir + 'runs/'
    # makedirs(exist_ok=True) creates missing parents and tolerates re-runs,
    # where os.mkdir would raise.
    os.makedirs(log_dir, exist_ok=True)
    self.writer = SummaryWriter(log_dir)

  def forward(self, x):
    """Encode then decode ``x``; the result is flattened to (batch, -1)."""
    out = self.Encoder(x)
    out = self.Decoder(out)
    out_flattened = out.view(x.shape[0], -1)
    return out_flattened

  def _reconstruction_loss(self, batch):
    """Return the MSE between the flattened reconstruction and the flattened input."""
    # The second element of the batch (labels) is not used.
    x, _ = batch
    logits = self.forward(x)
    x_flattened = x.view(x.shape[0], -1)
    return self.loss_fun(logits, x_flattened)

  def training_step(self, batch, batch_idx):
    """Compute, log and return the reconstruction loss for one training batch."""
    loss = self._reconstruction_loss(batch)
    self.log("train_loss", loss, prog_bar=True, on_step=False, on_epoch=True)

    # Written once per batch, all batches of an epoch share the x value current_epoch.
    self.writer.add_scalars('Loss', {'Training' : loss}, self.current_epoch)

    return loss

  def validation_step(self, batch, batch_idx):
    """Compute and log the reconstruction loss for one validation batch."""
    loss = self._reconstruction_loss(batch)
    self.log("val_loss", loss, prog_bar=True, on_step=False, on_epoch=True)

    self.writer.add_scalars('Loss', {'Validation' : loss}, self.current_epoch)

  def test_step(self, batch, batch_idx):
    """Compute and log the reconstruction loss for one test batch."""
    loss = self._reconstruction_loss(batch)
    self.log("test_loss", loss, prog_bar=True, on_step=False, on_epoch=True)

  def configure_optimizers(self):
    """Return an Adam optimizer over all parameters of the module."""
    optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)

    return optimizer

  def train_dataloader(self):
    """Return the loader passed as ``trainDataLoader``."""
    return self.trainDataLoader

  def val_dataloader(self):
    """Return the loader passed as ``valDataLoader``."""
    return self.valDataLoader

  def test_dataloader(self):
    """Return the loader passed as ``testDataLoader``."""
    return self.testDataLoader

  def extractSubModules(self):
    """Return the ``(encoder, decoder)`` pair held by this module."""
    return self.Encoder, self.Decoder

# Training the autoencoder

In [None]:
Max_Epochs = 100
Case_Dir = "2_blocks_all_sets/Stage1/"
full_dir = root_dir + Case_Dir
# Make sure the run directory exists before anything is written below it
os.makedirs(full_dir, exist_ok=True)

autoencoder = AutoEncoder(learning_rate = 1e-4, encoder=Encoder(), decoder=Decoder(), trainDataLoader=unlabelled_loader, valDataLoader=val_loader, testDataLoader=test_loader)

# Keep the single checkpoint with the lowest validation reconstruction loss.
# The module logs "val_loss" (there is no "val_acc"), and a loss must be minimised.
checkpoint_callback = ModelCheckpoint(monitor = "val_loss",
                                      dirpath = full_dir,
                                      save_top_k=1,
                                      mode="min",
                                      every_n_epochs=1
                                      )

print(full_dir)

trainer = Trainer(
    accelerator="auto",
    # One device on whichever accelerator is selected (GPU when available, else CPU)
    devices = 1,
    max_epochs = Max_Epochs,
    # The checkpoint callback only takes effect when it is handed to the Trainer
    callbacks = [TQDMProgressBar(refresh_rate=20), checkpoint_callback],
    logger=CSVLogger(save_dir= full_dir),
    deterministic=True
)

# Path of a Lightning .ckpt file to resume from; None starts training from scratch.
ckpt_path = None

trainer.fit(autoencoder, ckpt_path=ckpt_path)

# Evaluate Model (the in-memory weights, i.e. the same ones saved below)
trainer.test(autoencoder)

# Save Encoder & Decoder

# autoencoder.freeze()
model_path = full_dir + "Models/"
# torch.save does not create missing directories
os.makedirs(model_path, exist_ok=True)
timestamp = datetime.now().strftime('%Y%m%d%H%M%S')

enc, dec = autoencoder.extractSubModules()
# Both pairs of names refer to the same trained sub-modules
enc_train_1, dec_train_1 = enc, dec
torch.save(autoencoder.state_dict(), model_path+timestamp+"_auto.pth")
torch.save(enc.state_dict(), model_path+timestamp+"_enc.pth")
torch.save(dec.state_dict(), model_path+timestamp+"_dec.pth")

/content/drive/Shareddrives/ECE6179_project/2_blocks_all_sets/Stage1/


INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name           | Type             | Params
----------------------------------------------------
0 | loss_fun       | MSELoss          | 0     
1 | Encoder        | Encoder          | 683 K 
2 | Decoder        | Decoder          | 223 K 
3 | train_accuracy | MeanSquaredError | 0     
4 | vald_accuracy  | MeanSquaredError | 0     
5 | test_accuracy  | MeanSquaredError | 0     
----------------------------------------------------
906 K     Trainable params
0         Non-trainable params
906 K     Total params
3.625     Total estimated mod

Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]