<a href="https://colab.research.google.com/github/jhy9968/ECE6179_project/blob/main/unfreeze_fine_tuning_test.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
from torchvision.datasets import STL10 as STL10
import torchvision.transforms as transforms
from torch.utils.data import random_split
!pip install torchmetrics
from torchmetrics import Accuracy
import torchvision
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime
from google.colab import drive
drive.mount('/content/drive/')

# ---- Storage locations: edit these to match where the dataset should live ----
# Project root on the mounted shared Google Drive.
root_dir = "/content/drive/Shareddrives/ECE6179_project/"
# Folder handed to the STL10 dataset constructors as their root directory.
dataset_dir = f"{root_dir}CNN-VAE/data/"
# Alternative location when running on MonARCH:
# dataset_dir = "/mnt/lustre/projects/ds19/SHARED"

# ---- Data dimensions ----
# Side length in pixels; all images are 3x96x96.
image_size = 96
# Example batch size (the DataLoader cell assigns batch_size again before building loaders).
batch_size = 32

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting torchmetrics
  Downloading torchmetrics-0.9.3-py3-none-any.whl (419 kB)
[K     |████████████████████████████████| 419 kB 36.5 MB/s 
Installing collected packages: torchmetrics
Successfully installed torchmetrics-0.9.3
Mounted at /content/drive/


In [None]:
# Per-channel statistics passed to Normalize as (mean, std) by every pipeline below.
_CHANNEL_MEAN = (0.5, 0.5, 0.5)
_CHANNEL_STD = (0.5, 0.5, 0.5)


def _tensor_and_normalise():
    """Return fresh ToTensor + Normalize steps that end every pipeline."""
    return [transforms.ToTensor(), transforms.Normalize(_CHANNEL_MEAN, _CHANNEL_STD)]


# Labelled training data: random crop (with 4 px padding) and random mirroring
# for data augmentation.
transform_train = transforms.Compose([
    transforms.RandomCrop(image_size, padding=4),
    transforms.RandomHorizontalFlip(p=0.5),
    *_tensor_and_normalise(),
])

# Unlabelled data: random mirroring only.
transform_unlabelled = transforms.Compose([
    transforms.RandomHorizontalFlip(p=0.5),
    *_tensor_and_normalise(),
])

# Test data: no random augmentation, just a centre crop to image_size.
transform_test = transforms.Compose([
    transforms.CenterCrop(image_size),
    *_tensor_and_normalise(),
])


In [None]:
from torch.utils.data import Subset

# Load the labelled STL10 'train' split as two views over the same files:
# one with the augmenting training transform, one with the deterministic
# evaluation transform. Both views index the same samples.
trainval_set = STL10(dataset_dir, split='train', transform=transform_train, download=True)
trainval_eval_set = STL10(dataset_dir, split='train', transform=transform_test, download=False)

# Use 10% of the data for training - simulating a low-data scenario.
num_train = int(len(trainval_set)*0.1)

# Split the data into train/val index sets.
torch.manual_seed(0) #Set torch's random seed so that random split of data is reproducible
train_set, val_split = random_split(trainval_set, [num_train, len(trainval_set)-num_train])

# random_split returns Subsets that share the parent's transform, which would apply
# random crops/flips to validation images. Re-point the validation indices at the
# evaluation view so validation metrics are computed on un-augmented inputs.
val_set = Subset(trainval_eval_set, val_split.indices)

# Load the test set (deterministic transform).
test_set = STL10(dataset_dir, split='test', transform=transform_test, download=True)

Files already downloaded and verified
Files already downloaded and verified


In [None]:
# Unlabelled STL10 split, read through transform_unlabelled (random mirroring only).
# NOTE(review): the only loader built on this set (unlabelled_loader) is commented out
# in the DataLoader cell, so this load looks unused in the visible cells - confirm it
# is still needed before paying for it on every run.
unlabelled_set = STL10(dataset_dir, split='unlabeled', transform=transform_unlabelled, download=True)

Files already downloaded and verified


In [None]:
from torch.utils.data import DataLoader

# Mini-batch size shared by every loader below (replaces the example value set earlier).
batch_size = 50

# Only the training loader shuffles; validation and test keep dataset order.
train_loader = DataLoader(dataset=train_set, batch_size=batch_size, shuffle=True)
# unlabelled_loader = DataLoader(unlabelled_set, shuffle=True, batch_size=batch_size)
val_loader = DataLoader(dataset=val_set, batch_size=batch_size)
test_loader = DataLoader(dataset=test_set, batch_size=batch_size)

# Report how many batches each loader yields per epoch.
print(f'Size of train loader: {len(train_loader)}')
# print('Size of unlabelled loader:', len(unlabelled_loader))
print(f'Size of valid loader: {len(val_loader)}')
print(f'Size of test loader: {len(test_loader)}')

Size of train loader: 10
Size of valid loader: 90
Size of test loader: 160


In [None]:
# Sanity check: pull a single batch from the training loader and print its tensor shapes.
# unlabelled_loader is commented out in the DataLoader cell, so referencing it here
# raised NameError on a fresh kernel; train_loader is the loader that actually exists.
train_data_iter = iter(train_loader)
# Use the built-in next(): DataLoader iterators do not provide a .next() method
# in recent PyTorch releases.
data, labels = next(train_data_iter)
print(data.shape)
print(labels.shape)

In [None]:
# Load the TensorBoard notebook extension
%load_ext tensorboard

# Pick the first CUDA device when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)

In [None]:
! pip install --quiet "matplotlib" "pytorch-lightning" "pandas" "torchmetrics"
import matplotlib.pyplot as plt
import pytorch_lightning as pl
import torch
from torch import Tensor
import torch.nn as nn
from torchmetrics import Accuracy
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks.progress import TQDMProgressBar
from pytorch_lightning.loggers import CSVLogger
from torch.nn import functional as F
from torchmetrics import MeanSquaredError
from pytorch_lightning.callbacks import Callback, ModelCheckpoint
import torchvision.models.resnet as RN

from torchvision.models import resnet18
from torchvision.models.resnet import ResNet18_Weights

In [None]:
class Encoder(nn.Module):
  """Convolutional encoder assembled from the first eight children of a backbone model.

  The backbone's children are regrouped into four sequential blocks:
    block1 = children 0..4 (for torchvision's resnet18: conv1, bn1, relu, maxpool, layer1)
    block2 = child 5, block3 = child 6, block4 = child 7
  Children 5-7 must themselves be iterable containers of modules (e.g. nn.Sequential),
  because their contents are unpacked into a new nn.Sequential.

  Args:
    base_model: backbone whose children are reused. When None (the default), a new
      ImageNet-pretrained resnet18 is created for this instance. The backbone is
      built per instance rather than as a default argument value, so separate
      Encoder() objects never share layer objects (a shared default would make
      load_state_dict / freeze on one encoder silently affect every other one).
  """
  def __init__ (self, base_model=None):
    super(Encoder, self).__init__()
    if base_model is None:
      base_model = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1, progress=False)
    stages = list(base_model.children())
    self.block1 = nn.Sequential(*stages[:5])
    self.block2 = nn.Sequential(*stages[5])
    self.block3 = nn.Sequential(*stages[6])
    self.block4 = nn.Sequential(*stages[7])

  def forward(self, x):
    """Run x through block1..block4 in order and return the resulting feature map."""
    x = self.block1(x)
    x = self.block2(x)
    x = self.block3(x)
    x = self.block4(x)
    return x

  def print_model(self):
    """Print the module hierarchy of this encoder."""
    print(self)

  def _set_block_requires_grad(self, block, requires_grad):
    """Set requires_grad on every parameter of the 1-based block number `block`.

    Raises:
      ValueError: if `block` does not name one of this encoder's child blocks.
    """
    blocks = list(self.children())
    if not 1 <= block <= len(blocks):
      raise ValueError('block must be between 1 and '+str(len(blocks))+', got '+str(block))
    for param in blocks[block-1].parameters():
      param.requires_grad = requires_grad

  def freeze_param(self, block):
    """Stop gradient updates for all parameters in block number `block` (1-based)."""
    self._set_block_requires_grad(block, False)
    print('Freeze block '+str(block)+' parameters')
    
  def unfreeze_param(self, block):
    """Re-enable gradient updates for all parameters in block number `block` (1-based)."""
    self._set_block_requires_grad(block, True)
    print('Unfreeze block '+str(block)+' parameters')

Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth


In [None]:
class Decoder(nn.Module):
  """Decoder that upsamples a 512-channel feature map back to a 3-channel image.

  The network alternates five stride-2 transposed convolutions with five stride-1
  3x3 convolutions, with no activation functions in between. With kernel 3,
  stride 2, padding 1 and output_padding 1, each transposed convolution doubles
  the spatial size, so the output is 32x larger than the input along each spatial axis.

  Args:
    inplanes: stored on the instance; the layer widths below are fixed literals
      (512 -> 256 -> 128 -> 64 -> 64 -> 64 -> 3) and do not read this value.
    intMed_planes: stored on the instance; likewise not read by the layers.
  """
  def __init__ (self, inplanes = 512, intMed_planes = 64):
    super(Decoder, self).__init__()
    self.inplanes      = inplanes
    self.intMed_planes = intMed_planes

    # Hyper-parameters shared by all upsampling / refining layers.
    # output_padding is there to make the sizes match; be careful about the extra
    # line of zeros it introduces when building the loss function.
    upsample = dict(kernel_size = 3, stride = 2, padding = 1, output_padding = 1)
    refine   = dict(kernel_size = 3, stride = 1, padding = 1)

    # Attribute names and creation order are kept so saved state dicts still load.
    self.convTrans1 = nn.ConvTranspose2d(512, 256, **upsample)
    self.conv2      = nn.Conv2d(256, 256, **refine)
    self.convTrans3 = nn.ConvTranspose2d(256, 128, **upsample)
    self.conv4      = nn.Conv2d(128, 128, **refine)
    self.convTrans5 = nn.ConvTranspose2d(128, 64, **upsample)
    self.conv6      = nn.Conv2d(64, 64, **refine)
    self.convTrans7 = nn.ConvTranspose2d(64, 64, **upsample)
    self.conv8      = nn.Conv2d(64, 64, **refine)
    self.convTrans9 = nn.ConvTranspose2d(64, 64, **upsample)
    self.conv10     = nn.Conv2d(64, 3, **refine)

  def forward (self, x):
    """Apply the ten layers in definition order and return the reconstruction."""
    pipeline = (
      self.convTrans1, self.conv2,
      self.convTrans3, self.conv4,
      self.convTrans5, self.conv6,
      self.convTrans7, self.conv8,
      self.convTrans9, self.conv10,
    )
    for layer in pipeline:
      x = layer(x)
    return x



In [None]:
class AutoEncoder(LightningModule):
  """Lightning module that trains an encoder/decoder pair to reconstruct its input.

  The loss is the mean squared error between the flattened decoder output and the
  flattened input image; labels in the batch are ignored.

  Args:
    learning_rate: learning rate for the Adam optimizer.
    encoder: encoder module. When None (the default) a new Encoder() is built for
      this instance. Building it here instead of in the signature avoids a single
      module object being created at class-definition time and shared by every
      AutoEncoder that relies on the default.
    decoder: decoder module; None builds a new Decoder() for the same reason.
    trainDataLoader / valDataLoader / testDataLoader: loaders returned unchanged
      by train_dataloader / val_dataloader / test_dataloader.
  """
  def __init__ (self, learning_rate = 1e-4, encoder=None, decoder=None, trainDataLoader=None, valDataLoader=None, testDataLoader=None):
    super().__init__()

    self.learning_rate = learning_rate
    self.loss_fun = nn.MSELoss()

    self.Encoder = Encoder() if encoder is None else encoder
    self.Decoder = Decoder() if decoder is None else decoder

    # MSE metric objects; created here but not updated or logged by the step methods.
    self.train_accuracy = MeanSquaredError()
    self.vald_accuracy = MeanSquaredError()
    self.test_accuracy = MeanSquaredError()

    self.trainDataLoader = trainDataLoader
    self.valDataLoader = valDataLoader
    self.testDataLoader = testDataLoader
    

  def forward(self, x):
    """Encode then decode x and return the reconstruction flattened to (batch, -1)."""
    out = self.Encoder(x)
    out = self.Decoder(out)
    out_flattened = out.view(x.shape[0], -1)
    return out_flattened

  def _reconstruction_loss(self, batch):
    """Return the MSE between the flattened reconstruction and the flattened input."""
    x, _ = batch  # labels are not used for reconstruction
    reconstruction = self.forward(x)
    x_flattened = x.view(x.shape[0], -1)
    return self.loss_fun(reconstruction, x_flattened)

  def training_step(self, batch, batch_idx):
    """Log the epoch-averaged training loss and return it for back-propagation."""
    loss = self._reconstruction_loss(batch)
    self.log("train_loss", loss, prog_bar=True, on_step=False, on_epoch=True)
    return loss

  def validation_step(self, batch, batch_idx):
    """Log the epoch-averaged validation reconstruction loss."""
    loss = self._reconstruction_loss(batch)
    self.log("val_loss", loss, prog_bar=True, on_step=False, on_epoch=True)
  
  def test_step(self, batch, batch_idx):
    """Log the epoch-averaged test reconstruction loss."""
    loss = self._reconstruction_loss(batch)
    self.log("test_loss", loss, prog_bar=True, on_step=False, on_epoch=True)

  def configure_optimizers(self):
    """Return an Adam optimizer over all parameters of this module."""
    optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)

    return optimizer
  
  def train_dataloader(self):    
    return self.trainDataLoader
  
  def val_dataloader(self):
    return self.valDataLoader
  
  def test_dataloader(self):
    return self.testDataLoader

  def extractSubModules(self):
    """Return the (encoder, decoder) sub-modules, e.g. for saving them separately."""
    return self.Encoder, self.Decoder

In [None]:
# Restore the encoder/decoder weights saved under Mod_AE_v004/Models/.
Case_Dir = "Mod_AE_v004/"
path = root_dir+ Case_Dir+ "Models/"
# map_location="cpu" lets state dicts that were saved from a GPU load on a CPU-only
# runtime as well; load_state_dict below copies the values into each module's own
# parameters, so the modules can still be moved to an accelerator afterwards.
enc_rec4_statdic = torch.load(path+"20221002044719_AE_v004_enc.pth", map_location="cpu")
dec_rec4_statdic = torch.load(path+"20221002044719_AE_v004_dec.pth", map_location="cpu")

encoder_train_4 = Encoder()
encoder_train_4.load_state_dict(enc_rec4_statdic)

decoder_train_4 = Decoder()
decoder_train_4.load_state_dict(dec_rec4_statdic)

# Freeze encoder blocks 1 and 2; blocks 3 and 4 stay trainable for fine-tuning.
encoder_train_4.freeze_param(1)
encoder_train_4.freeze_param(2)
# encoder_train_4.freeze_param(3)
# encoder_train_4.freeze_param(4)



class AutoEncoder_Classifier(LightningModule):
  """Lightning classifier: an encoder followed by a two-layer fully connected head.

  forward() flattens the encoder output and feeds it through linear1 -> ReLU -> linear2,
  producing 10 scores per sample. linear1 takes 512*9 input features, so the
  flattened encoder output must have exactly that many values per sample.
  The GAP pooling layer is created but is not applied in forward().

  Args:
    learning_rate: learning rate for the Adam optimizer.
    enc: encoder module used as the feature extractor.
    trainDataLoader / valDataLoader / testDataLoader: loaders returned unchanged
      by train_dataloader / val_dataloader / test_dataloader.
  """
  def __init__ (self, learning_rate = 1e-3,
                enc = encoder_train_4, 
                trainDataLoader=train_loader, valDataLoader=val_loader, testDataLoader=test_loader):
    super().__init__()

    self.learning_rate = learning_rate
    self.loss_fun = nn.CrossEntropyLoss()

    # Feature extractor followed by the classification head.
    self.Encoder = enc
    self.act     = F.relu
    self.GAP     = nn.AdaptiveAvgPool2d(8)
    self.linear1 = nn.Linear(512*9, 256)
    self.linear2 = nn.Linear(256, 10)

    # One accuracy metric per stage so their running states stay separate.
    self.train_accuracy = Accuracy()
    self.val_accuracy = Accuracy()
    self.test_accuracy = Accuracy()

    self.trainDataLoader = trainDataLoader
    self.valDataLoader = valDataLoader
    self.testDataLoader = testDataLoader

  def forward(self, x):
    """Return class scores of shape (batch, 10) for the input batch x."""
    features = self.Encoder(x)
    flat = features.view(features.shape[0], -1)
    hidden = self.act(self.linear1(flat))
    return self.linear2(hidden)

  def _run_stage(self, batch, stage, metric):
    """Compute the loss for one batch, update `metric`, and log both.

    Logs "<stage>_loss" and "<stage>_acc" as epoch-level values and returns the loss.
    """
    inputs, targets = batch
    scores = self.forward(inputs)
    loss = self.loss_fun(scores, targets)

    predicted = scores.argmax(1)
    metric.update(predicted, targets)

    self.log(stage + "_loss", loss, prog_bar=True, on_step=False, on_epoch=True)
    self.log(stage + "_acc", metric, prog_bar=True, on_step=False, on_epoch=True)
    return loss

  def training_step(self, batch, batch_idx):
    """Run one training batch and return its loss for back-propagation."""
    return self._run_stage(batch, "train", self.train_accuracy)

  def validation_step(self, batch, batch_idx):
    """Run one validation batch; only logs, returns nothing."""
    self._run_stage(batch, "val", self.val_accuracy)

  def test_step(self, batch, batch_idx):
    """Run one test batch; only logs, returns nothing."""
    self._run_stage(batch, "test", self.test_accuracy)

  def configure_optimizers(self):
    """Return an Adam optimizer over all parameters of this module."""
    return torch.optim.Adam(params=self.parameters(), lr=self.learning_rate)

  def train_dataloader(self):
    return self.trainDataLoader

  def val_dataloader(self):
    return self.valDataLoader

  def test_dataloader(self):
    return self.testDataLoader

  def extractSubModules(self):
    """Return (encoder, linear1, linear2), e.g. for saving them separately."""
    return self.Encoder, self.linear1, self.linear2

In [None]:
import os

Max_Epochs = 400
Case_Dir = "Mod_AE_v004/Classifier/"
# Keep the single checkpoint with the highest validation accuracy, checked every epoch.
checkpoint_callback = ModelCheckpoint(monitor = "val_acc",
                                      dirpath = root_dir+ Case_Dir,
                                      save_top_k=1,
                                      mode="max",
                                      every_n_epochs=1
                                      )
print(root_dir+ Case_Dir)

# NOTE(review): train_loader_mod is not passed to the model below (train_loader is) -
# kept in case a later cell relies on it; confirm and remove if unused.
train_loader_mod = DataLoader(train_set, shuffle=True, batch_size=batch_size)


model_AE_Classifier_v002 = AutoEncoder_Classifier(learning_rate = 2e-5, trainDataLoader=train_loader, valDataLoader=val_loader, testDataLoader=test_loader)
# Lightning checkpoint from which training is resumed.
checkpoint_dir = root_dir+ Case_Dir+"lightning_logs/version_6/checkpoints/epoch=58-step=590.ckpt"

trainer_AE_Classfier = Trainer(
    accelerator="auto",
    devices = 1 if torch.cuda.is_available() else None,
    max_epochs = Max_Epochs,
    # checkpoint_callback must be registered here; it was previously created but never
    # handed to the Trainer, so the val_acc-monitored checkpoint was never written.
    callbacks = [TQDMProgressBar(refresh_rate=20), checkpoint_callback],
    logger=CSVLogger(save_dir= root_dir+ Case_Dir),
    deterministic=False,
    log_every_n_steps=10
)

# Create the output folder before the long training run so that a missing directory
# cannot make the final torch.save calls fail after training has finished.
model_path = root_dir+ Case_Dir+ "Models/"
os.makedirs(model_path, exist_ok=True)

trainer_AE_Classfier.fit(model_AE_Classifier_v002, ckpt_path=checkpoint_dir )

# Evaluate Model: explicitly test the checkpoint selected by checkpoint_callback.
trainer_AE_Classfier.test(ckpt_path="best")

# Save the full classifier plus its encoder and linear layers separately.

model_AE_Classifier_v002.freeze()
timestamp = datetime.now().strftime('%Y%m%d%H%M%S')

enc, lin1, lin2 = model_AE_Classifier_v002.extractSubModules()
torch.save(model_AE_Classifier_v002.state_dict(), model_path+timestamp+"_AEC_v002_mod.pth")
torch.save(enc.state_dict(), model_path+timestamp+"_AEC_v002_enc.pth")
torch.save(lin1.state_dict(), model_path+timestamp+"_AEC_v002_lin1.pth")
torch.save(lin2.state_dict(), model_path+timestamp+"_AEC_v002_lin2.pth")