In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torchsummary import summary
import matplotlib.pyplot as plt
import numpy as np
from sklearn.manifold import TSNE
import os
import torchvision.models as models
from record_utils import RecordManager


# When True, the first training cell restores model/optimizer state from a checkpoint.
LOAD_OLD_MODEL = False 

# Mini-batch sizes for the training, test and visualisation data loaders,
# and the number of passes over the training set.
BATCH_SIZE = 128
EPOCHS = 10
TEST_BATCH_SIZE = 64
VIZ_BATCH_SIZE = 5

# NOTE(review): this path looks like it is missing a "." before "pth", and it is
# not referenced by any cell visible in this notebook — confirm before using it.
CHECKPOINT_PATH = "./checkpoints/cifar/checkpoint_1.0pth"
# Folder handed to RecordManager by the training cells.
CHECKPOINT_FOLDER_PATH = "./models/cifar10"

# CIFAR-10 label index -> human-readable class name.
CIFAR10_CLASSES = {
    0: "airplane",
    1: "automobile",
    2: "bird",
    3: "cat",
    4: "deer",
    5: "dog",
    6: "frog",
    7: "horse",
    8: "ship",
    9: "truck"
}

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

cuda


In [2]:
# Load CIFAR-10 dataset

# Training pipeline: tensor conversion, per-channel normalisation, then a random
# horizontal flip as data augmentation.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # (x - 0.5) / 0.5 per channel -> values in [-1, 1]
    transforms.RandomHorizontalFlip(),
])

# Evaluation pipeline: same preprocessing but WITHOUT the random flip, so test
# accuracy, visualisations and extracted features are deterministic.
test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=False, transform=transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=False, transform=test_transform)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=TEST_BATCH_SIZE, shuffle=False)

# Small-batch loader over the test set, used only for displaying sample images.
viz_loader = torch.utils.data.DataLoader(test_dataset, batch_size=VIZ_BATCH_SIZE, shuffle=False)

In [None]:
# Report how many samples each split contains
print("Train size:", len(train_loader.dataset))
print("Test size:", len(test_loader.dataset))

from collections import defaultdict

# Tally the test-set labels one batch at a time
class_counts = defaultdict(int)
for _, labels in test_loader:
    for class_idx in labels.tolist():
        class_counts[class_idx] += 1

# One line per class, in the order the classes were first encountered
for class_idx, count in class_counts.items():
    print(f"Class {class_idx}: {count} samples")

In [None]:
def imshow(img):
    """Display an image tensor with matplotlib.

    The tensor is mapped back with ``img / 2 + 0.5`` (the inverse of a
    mean=0.5 / std=0.5 normalisation), converted to NumPy, and its first axis
    is moved last before plotting.

    Args:
        img: tensor supporting ``.numpy()``; assumed to be laid out as
            (channels, height, width) — TODO confirm for other callers.
    """
    unnormalized = img / 2 + 0.5  # Unnormalize
    channels_last = np.transpose(unnormalized.numpy(), (1, 2, 0))
    plt.imshow(channels_last)
    plt.show()

# Fetch the first batch from viz_loader (test set, not shuffled)
dataiter = iter(viz_loader)
images, labels = next(dataiter)

# Show images
imshow(torchvision.utils.make_grid(images))
# Print one class name per image actually present in the batch
print(' '.join('%5s' % train_dataset.classes[label] for label in labels))

In [6]:
## Extract a subset of images from two classes (e.g., cats and dogs)
#class_indices = [train_dataset.class_to_idx['cat'], train_dataset.class_to_idx['dog']]
#subset_indices = [i for i, label in enumerate(train_dataset.targets) if label in class_indices]
#subset_images = [train_dataset[i][0].numpy().flatten() for i in subset_indices]
#subset_labels = [train_dataset[i][1] for i in subset_indices]
#
## Convert the list of images to a NumPy array
#subset_images = np.array(subset_images)
#
## Perform t-SNE
#tsne = TSNE(n_components=2, random_state=42)
#tsne_result = tsne.fit_transform(subset_images)
#
## Create a scatter plot
#plt.figure(figsize=(8, 6))
#colors = ['red' if label == class_indices[0] else 'blue' for label in subset_labels]
#plt.scatter(tsne_result[:, 0], tsne_result[:, 1], c=colors, alpha=0.5)
#plt.title('t-SNE Visualization of Two Classes (Cats and Dogs) in CIFAR-10')
#plt.xlabel('t-SNE Component 1')
#plt.ylabel('t-SNE Component 2')
#plt.show()

In [None]:
class DeepCNN(nn.Module):
    """Convolutional classifier that produces 10 logits per image.

    ``forward`` chains conv1..conv5 with ReLU, three batch-norm layers, two
    2x2 max-pools and two dropout layers, followed by a three-layer fully
    connected head. ``conv6`` is constructed (so it is part of ``state_dict``)
    but is never called in ``forward``.

    ``fc1`` expects 32768 input features, i.e. 512 channels x 8 x 8, which
    matches 3x32x32 inputs after the two pooling steps.
    """

    @staticmethod
    def _conv3x3(in_channels, out_channels):
        """Build a 3x3 convolution with stride 1 and padding 1."""
        return nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)

    def __init__(self):
        super().__init__()

        # Creation order matters: it fixes the state_dict layout and the order
        # in which random initial weights are drawn.
        self.conv1 = self._conv3x3(3, 32)
        self.relu = nn.ReLU()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = self._conv3x3(32, 64)
        self.conv3 = self._conv3x3(64, 128)
        self.conv4 = self._conv3x3(128, 256)
        self.conv5 = self._conv3x3(256, 512)
        self.conv6 = self._conv3x3(512, 512)  # not used in forward()
        self.dropout1 = nn.Dropout(p=0.2)
        self.dropout2 = nn.Dropout(p=0.5)
        self.batchnorm1 = nn.BatchNorm2d(32)
        self.batchnorm2 = nn.BatchNorm2d(256)
        self.batchnorm3 = nn.BatchNorm2d(512)
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(32768, 256)
        self.fc2 = nn.Linear(256, 64)
        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        """Return class logits for a batch of images."""
        # Block 1: two convolutions, then halve the spatial size
        x = self.conv1(x)
        x = self.batchnorm1(self.relu(x))
        x = self.pool(self.relu(self.conv2(x)))

        # Block 2: light dropout between the convolutions, then halve again
        x = self.dropout1(self.relu(self.conv3(x)))
        x = self.relu(self.conv4(x))
        x = self.pool(self.batchnorm2(x))

        # Block 3: widest convolution followed by heavy dropout
        x = self.relu(self.conv5(x))
        x = self.dropout2(self.batchnorm3(x))

        # Classifier head
        x = self.flatten(x)
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        return self.fc3(x)

# Instantiate the network and move its parameters to the selected device.
model = DeepCNN().to(device)

# Print a per-layer summary for a single 3x32x32 input.
summary(model, (3, 32, 32))

In [None]:
class DeepCNNResNet(nn.Module):
    """ResNet-18 backbone with its final layer replaced for ``num_classes`` outputs.

    Dropout is applied to the pooled feature vector *before* the final linear
    layer. Applying it to the output logits instead would randomly zero whole
    class scores during training, which corrupts the loss rather than
    regularising the features.

    ``base_model.fc`` stays a plain ``nn.Linear`` so ``state_dict`` keys are
    the same as with the stock torchvision layout.

    Args:
        num_classes: number of output logits (default 10).
    """

    def __init__(self, num_classes=10):
        super(DeepCNNResNet, self).__init__()
        # Load a pre-trained ResNet18 model
        # NOTE(review): recent torchvision releases deprecate `pretrained=True`
        # in favour of `weights=...`; kept as-is because the installed version
        # is not visible from this notebook — confirm.
        self.base_model = models.resnet18(pretrained=True)

        # Modify the final fully connected layer to match the number of classes
        in_features = self.base_model.fc.in_features
        self.base_model.fc = nn.Linear(in_features, num_classes)

        # Dropout for regularization, applied to the features feeding `fc`
        self.dropout = nn.Dropout(p=0.5)

    def forward(self, x):
        """Return class logits for a batch of images."""
        backbone = self.base_model

        # Same steps as torchvision's ResNet forward pass, up to the pooled features
        x = backbone.conv1(x)
        x = backbone.bn1(x)
        x = backbone.relu(x)
        x = backbone.maxpool(x)
        x = backbone.layer1(x)
        x = backbone.layer2(x)
        x = backbone.layer3(x)
        x = backbone.layer4(x)
        x = backbone.avgpool(x)
        x = torch.flatten(x, 1)

        # Regularise the features, then classify
        return backbone.fc(self.dropout(x))
    
# Uncomment the next line to replace the active `model` with the ResNet-18 variant.
#model = DeepCNNResNet().to(device)
# While the line above stays commented out, this summarises whichever `model` is already defined.
summary(model, (3, 32, 32))

In [None]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.0001, weight_decay=0.001)
record_manager = RecordManager(CHECKPOINT_FOLDER_PATH)

# Optionally resume model and optimizer from a saved checkpoint
if LOAD_OLD_MODEL:
    # map_location lets a checkpoint saved on a GPU load on a CPU-only machine
    checkpoint = torch.load('./checkpoints/cifar/checkpoint_4.1.pth', map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    # Keep the restored optimizer: re-creating it right after this call would
    # throw away the state that was just loaded. To change a hyper-parameter on
    # resume, edit optimizer.param_groups instead (e.g. group['weight_decay']).
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

best_test_accuracy = 0.0

for epoch in range(EPOCHS):
    # ---- Training pass ----
    model.train()
    running_loss = 0.0
    total_correct = 0
    total_samples = 0
    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        _, predicted = outputs.max(1)
        total_correct += predicted.eq(labels).sum().item()
        total_samples += labels.size(0)
    accuracy = total_correct / total_samples
    # Mean loss per batch: the value printed below and the one recorded, matching
    # what the other training cells in this notebook pass to RecordManager.
    epoch_loss = running_loss / len(train_loader)

    print(f"Epoch {epoch+1}/{EPOCHS}, Loss: {epoch_loss}, Accuracy: {accuracy}")

    # ---- Evaluation pass ----
    total_correct = 0
    total_samples = 0
    model.eval()
    # No gradients are needed for evaluation; skipping graph construction saves memory
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, predicted = outputs.max(1)
            total_correct += predicted.eq(labels).sum().item()
            total_samples += labels.size(0)
    test_accuracy = total_correct / total_samples

    # Save a checkpoint only for a new best result above 80% test accuracy;
    # otherwise just record this epoch's metrics.
    if test_accuracy > 0.8 and test_accuracy > best_test_accuracy:
        record_manager.save_checkpoint(model, optimizer, epoch_loss, epoch, test_accuracy=test_accuracy, train_accuracy=accuracy)
        best_test_accuracy = test_accuracy
    else:
        record_manager.save_metrics(epoch_loss, epoch, train_accuracy=accuracy, test_accuracy=test_accuracy)
    print(f"    Test Accuracy: {test_accuracy:.4f}")

In [None]:
import seaborn as sns
def get_latent_features(model, data_loader, layer_name):
    """Collect the outputs of one layer of ``model`` over a whole data loader.

    The model is put in eval mode, a forward hook is attached to the requested
    layer, and every batch is run through the model without gradients.

    Args:
        model: ``nn.Module`` to run. Inputs are moved to the device of its
            first parameter (CPU if it has no parameters).
        data_loader: iterable yielding ``(inputs, labels)`` tensor pairs.
        layer_name: name of the layer as listed by ``model.named_modules()``;
            both top-level names (``'fc3'``) and dotted nested names
            (``'block.0'``) are accepted.

    Returns:
        Tuple ``(features, labels)`` of NumPy arrays, each concatenated over
        all batches along the first axis.

    Raises:
        KeyError: if ``layer_name`` is not a module of ``model``.
        ValueError: if ``data_loader`` yields no batches (nothing to concatenate).
    """
    model.eval()
    intermediate_outputs = []
    all_labels = []

    # Forward hook that stores a CPU copy of the layer's output for each batch
    def hook(module, input, output):
        intermediate_outputs.append(output.detach().cpu().numpy())

    # Public lookup that also resolves nested layers
    target_layer = dict(model.named_modules())[layer_name]

    # Run on whatever device the model lives on instead of relying on a global
    first_param = next(model.parameters(), None)
    model_device = first_param.device if first_param is not None else torch.device('cpu')

    hook_handler = target_layer.register_forward_hook(hook)
    try:
        with torch.no_grad():
            for inputs, labels in data_loader:
                _ = model(inputs.to(model_device))  # Trigger the forward pass to collect the intermediate features
                all_labels.append(labels.cpu().numpy())
    finally:
        # Always detach the hook, even if a forward pass fails, so it cannot
        # affect later forward passes
        hook_handler.remove()

    return np.concatenate(intermediate_outputs), np.concatenate(all_labels)


# Layer whose outputs are projected with t-SNE
intermediate_layer_name = 'fc3'  # You can change this to the desired layer

# Collect that layer's outputs and the matching labels over the test set
features, labels = get_latent_features(model, test_loader, intermediate_layer_name)

# Collapse any non-batch axes so that every sample is a single row
if features.ndim > 2:
    features = features.reshape(len(features), -1)

# Reduce to two dimensions (fixed random_state)
tsne = TSNE(n_components=2, random_state=42)
latent_tsne = tsne.fit_transform(features)

# Scatter the 2-D embedding, coloured by label
sns.set(style="whitegrid")
plt.figure(figsize=(10, 8))
tsne_x = latent_tsne[:, 0]
tsne_y = latent_tsne[:, 1]
scatter = sns.scatterplot(x=tsne_x, y=tsne_y, hue=labels, palette="tab10", legend="full")
scatter.set_title(f"t-SNE Visualization of Latent Features from {intermediate_layer_name}")
plt.show()

In [None]:
class DeepCNNMinReg(nn.Module):
    """CNN variant whose forward pass uses no dropout.

    ``forward`` chains conv1..conv4 and ``conv6`` with ReLU, three batch-norm
    layers and two 2x2 max-pools, followed by a three-layer fully connected
    head producing 10 logits. ``dropout1`` and ``dropout2`` are constructed
    but never called in ``forward``.

    ``fc1`` expects 65536 input features, i.e. 1024 channels x 8 x 8, which
    matches 3x32x32 inputs after the two pooling steps.
    """

    @staticmethod
    def _conv3x3(in_channels, out_channels):
        """Build a 3x3 convolution with stride 1 and padding 1."""
        return nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)

    def __init__(self):
        super().__init__()

        # Creation order matters: it fixes the state_dict layout and the order
        # in which random initial weights are drawn.
        self.conv1 = self._conv3x3(3, 32)
        self.relu = nn.ReLU()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = self._conv3x3(32, 64)
        self.conv3 = self._conv3x3(64, 128)
        self.conv4 = self._conv3x3(128, 256)
        self.conv6 = self._conv3x3(256, 1024)
        self.dropout1 = nn.Dropout(p=0.2)  # not used in forward()
        self.dropout2 = nn.Dropout(p=0.5)  # not used in forward()
        self.batchnorm1 = nn.BatchNorm2d(32)
        self.batchnorm2 = nn.BatchNorm2d(256)
        self.batchnorm3 = nn.BatchNorm2d(1024)
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(65536, 256)
        self.fc2 = nn.Linear(256, 64)
        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        """Return class logits for a batch of images."""
        # Block 1: two convolutions, then halve the spatial size
        x = self.conv1(x)
        x = self.batchnorm1(self.relu(x))
        x = self.pool(self.relu(self.conv2(x)))

        # Block 2: two more convolutions, then halve again
        x = self.relu(self.conv3(x))
        x = self.relu(self.conv4(x))
        x = self.pool(self.batchnorm2(x))

        # Block 3: widest convolution
        x = self.relu(self.conv6(x))
        x = self.batchnorm3(x)

        # Classifier head
        x = self.flatten(x)
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        return self.fc3(x)

# Instantiate the dropout-free variant on the selected device and print its layer summary.
noreg_model = DeepCNNMinReg().to(device)
summary(noreg_model, (3, 32, 32))   

In [None]:
!pip3 install tqdm

In [15]:
from models import CNN
from tqdm import tqdm
# Use the shared `device` instead of a hard-coded "cuda" so this cell also runs on CPU-only machines
model = CNN([256,256,256,128,128]).to(device)

summary(model, (3, 32, 32))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1          [-1, 256, 16, 16]          12,288
       BatchNorm2d-2          [-1, 256, 16, 16]             512
              ReLU-3          [-1, 256, 16, 16]               0
            Conv2d-4            [-1, 256, 8, 8]       1,048,576
       BatchNorm2d-5            [-1, 256, 8, 8]             512
              ReLU-6            [-1, 256, 8, 8]               0
            Conv2d-7            [-1, 256, 4, 4]       1,048,576
       BatchNorm2d-8            [-1, 256, 4, 4]             512
              ReLU-9            [-1, 256, 4, 4]               0
           Conv2d-10            [-1, 128, 2, 2]         524,288
      BatchNorm2d-11            [-1, 128, 2, 2]             256
             ReLU-12            [-1, 128, 2, 2]               0
           Conv2d-13            [-1, 128, 1, 1]          65,664
Total params: 2,701,184
Trainable param

In [14]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.00001, weight_decay=0.001)
record_manager = RecordManager(CHECKPOINT_FOLDER_PATH)

# Switches for what gets persisted through record_manager
SAVE_METRICS = True
SAVE_MODEL = True

# Start from zero for this model so the cell does not depend on a value left
# over from an earlier training cell
best_test_accuracy = 0.0

for epoch in range(EPOCHS):
    # ---- Training pass ----
    model.train()
    running_loss = 0.0
    total_correct = 0
    total_samples = 0
    for inputs, labels in tqdm(train_loader):
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        _, predicted = outputs.max(1)
        total_correct += predicted.eq(labels).sum().item()
        total_samples += labels.size(0)
    accuracy = total_correct / total_samples
    epoch_loss = running_loss / len(train_loader)  # mean loss per batch

    print(f"Epoch {epoch+1}/{EPOCHS}, Loss: {epoch_loss}, Accuracy: {accuracy}")

    # ---- Evaluation pass ----
    total_correct = 0
    total_samples = 0
    model.eval()
    # No gradients are needed for evaluation; skipping graph construction saves memory
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, predicted = outputs.max(1)
            total_correct += predicted.eq(labels).sum().item()
            total_samples += labels.size(0)
    test_accuracy = total_correct / total_samples

    # Save a checkpoint only for a new best result above 80% test accuracy;
    # otherwise just record this epoch's metrics.
    if test_accuracy > 0.8 and test_accuracy > best_test_accuracy:
        if SAVE_MODEL:
            record_manager.save_checkpoint(model, optimizer, epoch_loss, epoch, test_accuracy=test_accuracy, train_accuracy=accuracy)
        best_test_accuracy = test_accuracy
    else:
        if SAVE_METRICS:
            record_manager.save_metrics(epoch_loss, epoch, train_accuracy=accuracy, test_accuracy=test_accuracy)
    print(f"    Test Accuracy: {test_accuracy:.4f}")

100%|██████████| 391/391 [00:16<00:00, 24.36it/s]


Epoch 1/10, Loss: 2.6707916875629474, Accuracy: 0.35782
Metrics updated in: ./models/cifar10\run10\accuracies\all_accuracies.json
    Test Accuracy: 0.4113


100%|██████████| 391/391 [00:14<00:00, 26.73it/s]


Epoch 2/10, Loss: 1.9226759080691715, Accuracy: 0.43722
Metrics updated in: ./models/cifar10\run10\accuracies\all_accuracies.json
    Test Accuracy: 0.4682


100%|██████████| 391/391 [00:16<00:00, 24.33it/s]


Epoch 3/10, Loss: 1.6397298870184231, Accuracy: 0.48246
Metrics updated in: ./models/cifar10\run10\accuracies\all_accuracies.json
    Test Accuracy: 0.5028


100%|██████████| 391/391 [00:15<00:00, 25.01it/s]


Epoch 4/10, Loss: 1.4824046427026734, Accuracy: 0.51676
Metrics updated in: ./models/cifar10\run10\accuracies\all_accuracies.json
    Test Accuracy: 0.5250


100%|██████████| 391/391 [00:15<00:00, 25.33it/s]


Epoch 5/10, Loss: 1.37549695029588, Accuracy: 0.54462
Metrics updated in: ./models/cifar10\run10\accuracies\all_accuracies.json
    Test Accuracy: 0.5508


100%|██████████| 391/391 [00:13<00:00, 28.66it/s]


Epoch 6/10, Loss: 1.2903986538157743, Accuracy: 0.56868
Metrics updated in: ./models/cifar10\run10\accuracies\all_accuracies.json
    Test Accuracy: 0.5694


100%|██████████| 391/391 [00:15<00:00, 25.53it/s]


Epoch 7/10, Loss: 1.2189611027307827, Accuracy: 0.59242
Metrics updated in: ./models/cifar10\run10\accuracies\all_accuracies.json
    Test Accuracy: 0.5882


100%|██████████| 391/391 [00:14<00:00, 26.39it/s]


Epoch 8/10, Loss: 1.155759786553395, Accuracy: 0.61134
Metrics updated in: ./models/cifar10\run10\accuracies\all_accuracies.json
    Test Accuracy: 0.6022


100%|██████████| 391/391 [00:14<00:00, 27.33it/s]


Epoch 9/10, Loss: 1.1002621210139731, Accuracy: 0.63074
Metrics updated in: ./models/cifar10\run10\accuracies\all_accuracies.json
    Test Accuracy: 0.6161


100%|██████████| 391/391 [00:13<00:00, 28.44it/s]


Epoch 10/10, Loss: 1.050122105404544, Accuracy: 0.64788
Metrics updated in: ./models/cifar10\run10\accuracies\all_accuracies.json
    Test Accuracy: 0.6293


In [None]:
class SimpleCNN(nn.Module):
    """Single-convolution baseline classifier producing 10 logits per image.

    ``forward`` applies conv1 -> ReLU -> batch norm -> 2x2 max-pool and then a
    three-layer fully connected head. ``dropout1`` and ``dropout2`` are
    constructed but never called in ``forward``.

    ``fc1`` expects 8192 input features, i.e. 32 channels x 16 x 16, which
    matches 3x32x32 inputs after the single pooling step.
    """

    def __init__(self):
        super().__init__()

        # Creation order matters: it fixes the state_dict layout and the order
        # in which random initial weights are drawn.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, stride=1, padding=1)
        self.relu = nn.ReLU()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        # Additional convolutions, currently disabled:
        #self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        #self.conv3 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)

        self.dropout1 = nn.Dropout(p=0.2)  # not used in forward()
        self.dropout2 = nn.Dropout(p=0.5)  # not used in forward()
        self.batchnorm1 = nn.BatchNorm2d(32)
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(8192, 256)
        self.fc2 = nn.Linear(256, 64)
        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        """Return class logits for a batch of images."""
        # Feature extractor: one convolution, then halve the spatial size
        x = self.relu(self.conv1(x))
        x = self.pool(self.batchnorm1(x))

        # Classifier head
        x = self.flatten(x)
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        return self.fc3(x)
    

# Instantiate the baseline model on the selected device and print its layer summary.
simple_model = SimpleCNN().to(device)
summary(simple_model, (3, 32, 32))

In [None]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(simple_model.parameters(), lr=0.0001, weight_decay=0.001)
record_manager = RecordManager(CHECKPOINT_FOLDER_PATH)
from models import TriggerSensitiveCNN

# Switches for what gets persisted through record_manager
SAVE_METRICS = True
SAVE_MODEL = True

best_test_accuracy = 0

for epoch in range(EPOCHS):
    # ---- Training pass ----
    simple_model.train()
    running_loss = 0.0
    total_correct = 0
    total_samples = 0
    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = simple_model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        _, predicted = outputs.max(1)
        total_correct += predicted.eq(labels).sum().item()
        total_samples += labels.size(0)
    accuracy = total_correct / total_samples
    epoch_loss = running_loss / len(train_loader)  # mean loss per batch

    print(f"Epoch {epoch+1}/{EPOCHS}, Loss: {epoch_loss}, Accuracy: {accuracy}")

    # ---- Evaluation pass ----
    total_correct = 0
    total_samples = 0
    simple_model.eval()
    # No gradients are needed for evaluation; skipping graph construction saves memory
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = simple_model(inputs)
            _, predicted = outputs.max(1)
            total_correct += predicted.eq(labels).sum().item()
            total_samples += labels.size(0)
    test_accuracy = total_correct / total_samples

    # Save a checkpoint only for a new best result above 80% test accuracy;
    # otherwise just record this epoch's metrics.
    if test_accuracy > 0.8 and test_accuracy > best_test_accuracy:
        if SAVE_MODEL:
            record_manager.save_checkpoint(simple_model, optimizer, epoch_loss, epoch, test_accuracy=test_accuracy, train_accuracy=accuracy)
        best_test_accuracy = test_accuracy
    else:
        if SAVE_METRICS:
            record_manager.save_metrics(epoch_loss, epoch, train_accuracy=accuracy, test_accuracy=test_accuracy)
    print(f"    Test Accuracy: {test_accuracy:.4f}")