# Imports

In [1]:
import argparse
import os
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
from torchvision.datasets import STL10
from torch.utils.data import DataLoader
import torchvision.utils as vutils
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animat

# Seed Init

In [2]:
# Set random seed for reproducibility
manualSeed = 999
#manualSeed = random.randint(1, 10000) # use if you want new results
print("Random Seed: ", manualSeed)
random.seed(manualSeed)
torch.manual_seed(manualSeed)
torch.use_deterministic_algorithms(True) # Needed for reproducible results

Random Seed:  999


# Set Parameters

In [10]:
# Number of channels in the training images. For color images this is 3
nc = 3

# Size of z latent vector (i.e. size of generator input)
nz = 64

# Size of feature maps in discriminator
ndf = 96

# Size of feature maps in generator
ngf = 96

# Number of training epochs
num_epochs = 100

# Learning rate for optimizers
lr=0.001

# Beta1 hyperparameter for Adam optimizers
beta1 = 0.5

# Number of GPUs available. Use 0 for CPU mode.
ngpu = 1


# Decide which device we want to run on
device = torch.device("cuda:0" if (torch.cuda.is_available() and ngpu > 0) else "cpu")

# Dataset

In [11]:
batch_size = 256
image_size = 96

# Create a new transformation that resizes the images
transform=transforms.Compose([
                               transforms.Resize(image_size),
                               transforms.CenterCrop(image_size),
                               transforms.ToTensor(),
                               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])


# Load STL-10 dataset
train_dataset = STL10(root='./data', split='train+unlabeled', transform=transform, download=False)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
print(len(train_dataset))
print(len(train_loader))

test_dataset = STL10(root='./data', split='test', transform=transform, download=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)
print(len(test_dataset))
print(len(test_loader))

105000
411
8000
32


# Weight Initialization

In [12]:
# custom weights initialization called on ``netG`` and ``netD``
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)

# Encoder Model

In [13]:
import torch.nn as nn

class Encoder(nn.Module):
    def __init__(self, ngpu, dim_z):
        super(Encoder, self).__init__()
        self.ngpu = ngpu
        nc = 3  # Number of input channels for the 96x96x3 image
        self.main = nn.Sequential(
            # input is (nc) x 96 x 96
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # state size: (ndf) x 48 x 48
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # state size: (ndf*2) x 24 x 24
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # state size: (ndf*4) x 12 x 12
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # state size: (ndf*8) x 6 x 6
            nn.Conv2d(ndf * 8, dim_z, 6, 1, 0, bias=False)
        )

    def forward(self, input):
        z = self.main(input)
        return z




# Instantiate the encoder
netD = Encoder(ngpu=0, dim_z=64).to(device)

# Handle multi-GPU
if (device.type == 'cuda') and (ngpu > 1):
    netD = nn.DataParallel(netD, list(range(ngpu)))

# Randomly initialize all weights
netD.apply(weights_init)

# Print the model
print(netD)

Encoder(
  (main): Sequential(
    (0): Conv2d(3, 96, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
    (2): Conv2d(96, 192, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (3): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (4): LeakyReLU(negative_slope=0.2, inplace=True)
    (5): Conv2d(192, 384, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (6): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (7): LeakyReLU(negative_slope=0.2, inplace=True)
    (8): Conv2d(384, 768, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (9): BatchNorm2d(768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (10): LeakyReLU(negative_slope=0.2, inplace=True)
    (11): Conv2d(768, 64, kernel_size=(6, 6), stride=(1, 1), bias=False)
  )
)


# Decoder (Generator) Model

In [14]:
# Generator Code

class Decoder(nn.Module):
    def __init__(self, ngpu, dim_z):
        super(Decoder, self).__init__()
        self.ngpu = ngpu
        self.main = nn.Sequential(
            # input is Z 64x1x1, going into a convolution
            nn.ConvTranspose2d( dim_z, ngf * 8, 6, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            # state size: (ndf*8) x 6 x 6
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # state size: (ndf*4) x 12 x 12
            nn.ConvTranspose2d( ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # state size: (ndf*2) x 24 x 24
            nn.ConvTranspose2d( ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            # state size: (ndf) x 48 x 48
            nn.ConvTranspose2d( ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh()
            # input is (nc) x 96 x 96
        )

    def forward(self, input):
        return self.main(input)


# Instantiate the decoder
netG = Decoder(ngpu=0, dim_z=64).to(device)

# Handle multi-GPU
if (device.type == 'cuda') and (ngpu > 1):
    netG = nn.DataParallel(netG, list(range(ngpu)))

# Randomly initialize all weights
netG.apply(weights_init)

# Print the model
print(netG)

Decoder(
  (main): Sequential(
    (0): ConvTranspose2d(64, 768, kernel_size=(6, 6), stride=(1, 1), bias=False)
    (1): BatchNorm2d(768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): ConvTranspose2d(768, 384, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (4): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): ConvTranspose2d(384, 192, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (7): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): ReLU(inplace=True)
    (9): ConvTranspose2d(192, 96, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (10): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (11): ReLU(inplace=True)
    (12): ConvTranspose2d(96, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (13): Tanh()
  )
)


# Criterion / Loss Function

In [15]:
criterion = nn.MSELoss()

# Optimizer

In [16]:
params = list(netD.parameters()) + list(netG.parameters())

optimizer = torch.optim.Adam(params, lr=lr, weight_decay=0)

# Training Loop

In [18]:
# Training loop

best_loss = float('inf')
best_model_state = None
num_train = len(train_dataset)




# Set up total loss/acc trackers
all_loss = []
all_acc = []
all_correct = 0
train_running_total = 0



# Set up epochal loss/acc trackers
epoch_loss = []
epoch_acc = []


# Set up validation loss/acc trackers
val_loss = []
val_acc = []
val_running_total = 0



print("Starting Training Loop...")

# For each epoch
for epoch in range(num_epochs):
    # Refresh Epoch Statistics
    print('reset epoch statistics')
    epoch_correct = 0
    epoch_loss_val = 0

    
    # Set Network to Train Mode
    netD.train()

    
    # For each batch in the dataloader
    for i, (data, _) in enumerate(train_loader, 0):

        # Put train data to device (CPU, GPU, or TPU)
        x = data.to(device)

        #  what does this do? why is this needed here?
        optimizer.zero_grad()
        
        # Forward pass batch through D
        z = netD(x)
        
        x_bar = netG(z)
        
        # Calculate loss on batch
        loss = criterion(x_bar, x)
        loss.backward()
        optimizer.step()
        


        # Update All Data
        all_loss.append(loss.item())
        

        print(f'iteration {i} current loss: {loss.item()}')
        
        # Log All metrics to wandb
        #wandb.log({"All Loss": loss.item()})


        # Update Epoch Data
        epoch_loss_val += loss.item()
            



    # Compute Epoch Loss at end of Epoch

    avg_epoch_loss = epoch_loss_val / len(train_loader)
    epoch_loss.append(avg_epoch_loss)

    print(f'\t\tEpoch {epoch}/{num_epochs} complete. Epoch loss {avg_epoch_loss}')
    
    # Log Epoch metrics to wandb
    #wandb.log({"Epoch Loss": avg_epoch_loss})



    # Validation Step
    print('Starting Validation Loop...')


    
    # Refresh Validation Statistics
    print('reset Validation statistics')
    val_correct = 0
    val_loss_value = 0

    
    # Set the model to valuation mode
    netD.eval()  

    
    # Iterate over the validation dataset in batches
    with torch.no_grad():
        for data, _ in test_loader:
            # Put val data to device (CPU, GPU, or TPU)
            x = data.to(device)

            
            # Forward pass batch through D
            z = netD(x)

            # Forward pass z through G
            x_hat = netG(z)

            # Calculate loss on validation batch
            v_loss = criterion(x_hat, x)
            #wandb.log({"Epoch val_loss": v_loss.item()}) 
        
            
            # Update Val Data
            val_loss_value += v_loss.item()


    val_loss_value /= len(test_loader)
    
    val_loss.append(val_loss_value)
    
    print(f"\t\tValidation Epoch {epoch}, Validation Loss: {val_loss_value}")


    # Log metrics to wandb
    #wandb.log({"Validation Loss": val_loss_value})






    if avg_epoch_loss < best_loss:
        best_loss = avg_epoch_loss
        print(f'best loss {best_loss}')
        best_model_state = netD.state_dict()
    
    
# Load the best model
if best_model_state is not None:
    PATH = './models/ae_pretraining.pth'
    torch.save(best_model_state, PATH)
    print("Loaded the model with the lowest loss.")

Starting Training Loop...
reset epoch statistics
iteration 0 current loss: 0.38199812173843384
iteration 1 current loss: 0.3522549867630005
iteration 2 current loss: 0.3098854422569275
iteration 3 current loss: 0.28222528100013733
iteration 4 current loss: 0.27860334515571594
iteration 5 current loss: 0.25685155391693115
iteration 6 current loss: 0.26178795099258423
iteration 7 current loss: 0.25403350591659546
iteration 8 current loss: 0.2521161437034607
iteration 9 current loss: 0.24196098744869232
iteration 10 current loss: 0.25181105732917786
iteration 11 current loss: 0.23587338626384735
iteration 12 current loss: 0.22920696437358856
iteration 13 current loss: 0.23533661663532257
iteration 14 current loss: 0.23220700025558472
iteration 15 current loss: 0.22536370158195496
iteration 16 current loss: 0.2169862985610962
iteration 17 current loss: 0.2127988636493683
iteration 18 current loss: 0.20925939083099365
iteration 19 current loss: 0.1953972429037094
iteration 20 current loss: 

KeyboardInterrupt: 

In [19]:
model.main.load_state_dict(checkpoint_for_netD)
model.load_state_dict(checkpotint_for_netD) -> not all keys in the checkpoint, because you have fc

SyntaxError: invalid syntax (1184571265.py, line 2)

In [21]:
netD.state_dict().keys()

odict_keys(['main.0.weight', 'main.2.weight', 'main.3.weight', 'main.3.bias', 'main.3.running_mean', 'main.3.running_var', 'main.3.num_batches_tracked', 'main.5.weight', 'main.6.weight', 'main.6.bias', 'main.6.running_mean', 'main.6.running_var', 'main.6.num_batches_tracked', 'main.8.weight', 'main.9.weight', 'main.9.bias', 'main.9.running_mean', 'main.9.running_var', 'main.9.num_batches_tracked', 'main.11.weight'])

In [38]:
import torch.nn as nn

class Supervised(nn.Module):
    def __init__(self, ngpu, dim_z, num_classes):
        super(Supervised, self).__init__()
        self.ngpu = ngpu
        nc = 3  # Number of input channels for the 96x96x3 image
        self.main = nn.Sequential(
            # input is (nc) x 96 x 96
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # state size: (ndf) x 48 x 48
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # state size: (ndf*2) x 24 x 24
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # state size: (ndf*4) x 12 x 12
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # state size: (ndf*8) x 6 x 6
            nn.Conv2d(ndf * 8, dim_z, 6, 1, 0, bias=False)
        )
        self.fc = nn.Linear(dim_z, num_classes)

    def forward(self, input):
        z = self.main(input)
        z = z.view(input.size(0), -1)  # Flatten z to (batch_size, dim_z)
        c = self.fc(z)
        return c

# Instantiate the model
model = Supervised(ngpu=0, dim_z=64, num_classes=10).to(device)

In [32]:
model.state_dict().keys()

odict_keys(['main.0.weight', 'main.2.weight', 'main.3.weight', 'main.3.bias', 'main.3.running_mean', 'main.3.running_var', 'main.3.num_batches_tracked', 'main.5.weight', 'main.6.weight', 'main.6.bias', 'main.6.running_mean', 'main.6.running_var', 'main.6.num_batches_tracked', 'main.8.weight', 'main.9.weight', 'main.9.bias', 'main.9.running_mean', 'main.9.running_var', 'main.9.num_batches_tracked', 'main.11.weight', 'fc.weight', 'fc.bias'])

In [37]:
model.main.load_state_dict(netD.main.state_dict())

<All keys matched successfully>

In [41]:
model.load_state_dict(netD.state_dict(), strict=False)

_IncompatibleKeys(missing_keys=['fc.weight', 'fc.bias'], unexpected_keys=[])

In [42]:
torch.sum(model.main[0].weight != netD.main[0].weight)

tensor(0, device='cuda:0')

In [44]:
model.state_dict().keys()

odict_keys(['main.0.weight', 'main.2.weight', 'main.3.weight', 'main.3.bias', 'main.3.running_mean', 'main.3.running_var', 'main.3.num_batches_tracked', 'main.5.weight', 'main.6.weight', 'main.6.bias', 'main.6.running_mean', 'main.6.running_var', 'main.6.num_batches_tracked', 'main.8.weight', 'main.9.weight', 'main.9.bias', 'main.9.running_mean', 'main.9.running_var', 'main.9.num_batches_tracked', 'main.11.weight', 'fc.weight', 'fc.bias'])

In [46]:
model.load_state_dict(model.state_dict())

<All keys matched successfully>