In [1]:
import os
import numpy as np
import random

import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from torch.utils.data import random_split, Subset
from torch.optim import lr_scheduler

# Pin every RNG source used in this notebook (Python, NumPy, PyTorch) to the same seed
# so that the data split and memorization indices drawn later are repeatable.
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)
if torch.cuda.is_available():
    # Seed the generators of all visible GPUs as well
    torch.cuda.manual_seed_all(42)

In [2]:
from song_code_reproduction.src.mnist_fcn import FCN_parameters, FCModel, LEARNING_RATE, WEIGHT_DECAY

from song_code_reproduction.src.org_cve_loss import CVELoss, reconstruct_from_params


In [3]:
# Directory where the best checkpoint is written during training and read back afterwards
save_path = "/dt/yisroel/Users/Data_Memorization/song_memorization/CVE/"

# Number of training images sampled into the memorization set
memorization_size = 100

# Training Params
num_epochs = 60  # Example number of epochs
BATCH_SIZE = 128
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


In [4]:
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import random_split, Subset
import torch
from torch.utils.data import DataLoader

# Define transformations for training (currently tensor conversion only; add augmentation here)
train_transform = transforms.Compose([
    transforms.ToTensor(),
])

# Define transformations for validation, testing, and memorization set (no augmentation)
eval_transform = transforms.Compose([
    transforms.ToTensor(),
])

mem_transform = transforms.Compose([transforms.ToTensor(),])


# Load the MNIST training split twice, once per transform. The train and validation
# subsets must not share one underlying dataset object: assigning `.dataset.transform`
# on two Subsets of the same object makes the last assignment win for both.
full_trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=train_transform)
full_valset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=eval_transform)
testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=eval_transform)


# Split the full training set into training and validation indices (80% / 20%)
train_size = int(0.8 * len(full_trainset))
val_size = len(full_trainset) - train_size
train_indices, val_indices = random_split(range(len(full_trainset)), [train_size, val_size])

# Create training and validation datasets, each backed by the dataset object that
# carries its own transform
train_dataset = Subset(full_trainset, train_indices.indices)
val_dataset = Subset(full_valset, val_indices.indices)

# Build the memorization set: `memorization_size` images picked by a random permutation.
# NOTE: indices are drawn from the whole training split, so they can overlap the
# validation subset as well as the training subset.
memorization_size = min(memorization_size, len(full_trainset))
memorization_source = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=mem_transform)
memorization_indices = torch.randperm(len(full_trainset))[:memorization_size]
memorization_set = Subset(memorization_source, memorization_indices)


# Create DataLoaders (only the training loader is shuffled)
trainloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
valloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
testloader = DataLoader(testset, batch_size=BATCH_SIZE, shuffle=False)



print("DataLoaders created successfully!")
print(f"Number of images in training set: {len(train_dataset)}")
print(f"Number of images in validation set: {len(val_dataset)}")
print(f"Number of images in test set: {len(testset)}")
print(f"Number of images in memorization set: {len(memorization_set)}")

DataLoaders created successfully!
Number of images in training set: 48000
Number of images in validation set: 12000
Number of images in test set: 10000
Number of images in memorization set: 100


In [5]:
# Build the classifier from the imported hyperparameter dict, then place it on the selected device
model = FCModel(**FCN_parameters)
model = model.to(device)

In [6]:
# Classification objective
criterion = nn.CrossEntropyLoss()

# Auxiliary CVE objective built over the memorization set.
# K is derived from memorization_size (28*28 pixels per image) instead of a hard-coded
# 100 images, so it stays in step with the memorization set if that size is changed.
cve_loss = CVELoss(model=model,
                   dataset=memorization_set,
                   K=memorization_size*28*28,
                   device=device)

optimizer = optim.AdamW(model.parameters(), lr= LEARNING_RATE, weight_decay=WEIGHT_DECAY)

# Cosine decay of the learning rate over the full run (stepped once per epoch by the training loop)
scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

print("Training hyperparameters and objects defined.")
print(f"Learning Rate: {LEARNING_RATE}")
print(f"Number of Epochs: {num_epochs}")
print(f"Loss Function: {type(criterion).__name__}")
print(f"Optimizer Type: {type(optimizer).__name__}")
print(f"Scheduler Type: {type(scheduler).__name__}")
print(f"Device: {device}")

Training hyperparameters and objects defined.
Learning Rate: 0.0001
Number of Epochs: 60
Loss Function: CrossEntropyLoss
Optimizer Type: AdamW
Scheduler Type: CosineAnnealingLR
Device: cuda


In [7]:
import torch
import time
import os

def train_batch(images, labels, model, optimizer, criterion, device, aux_loss_fn=None):
    """Run one optimization step on a single batch and return the scalar loss.

    The optimized loss is ``criterion(model(images), labels) + aux_loss_fn()``.

    Args:
        images: Input batch; moved to ``device`` before the forward pass.
        labels: Target batch; moved to ``device`` before the loss is computed.
        model: Network being trained.
        optimizer: Optimizer that updates ``model``'s parameters.
        criterion: Supervised loss called as ``criterion(outputs, labels)``.
        device: Device the batch is moved to.
        aux_loss_fn: Zero-argument callable returning an extra loss term that is
            added to the supervised loss. Defaults to the notebook-level
            ``cve_loss`` object, which keeps existing call sites unchanged.

    Returns:
        float: The combined loss value for this batch.
    """
    if aux_loss_fn is None:
        # Fall back to the notebook global so existing six-argument callers keep working
        aux_loss_fn = cve_loss

    images, labels = images.to(device), labels.to(device)

    # Forward pass
    outputs = model(images)
    loss = criterion(outputs, labels) + aux_loss_fn()

    # Backward and optimize
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    return loss.item()

def evaluate_batch(images, labels, model, criterion, device):
    """Score one batch without tracking gradients.

    Args:
        images: Input batch; moved to ``device``.
        labels: Target batch; moved to ``device``.
        model: Callable producing per-class scores of shape (batch, classes).
        criterion: Loss called as ``criterion(outputs, labels)``.
        device: Device the batch is moved to.

    Returns:
        tuple: ``(loss, correct, total)`` where ``loss`` is the batch loss as a
        float, ``correct`` is the number of samples whose highest-scoring class
        equals the label, and ``total`` is the batch size.
    """
    images = images.to(device)
    labels = labels.to(device)

    with torch.no_grad():
        logits = model(images)
        batch_loss = criterion(logits, labels)

        # Index of the highest score along the class dimension is the prediction
        predicted = torch.max(logits, 1).indices
        num_correct = (predicted == labels).sum().item()

    return batch_loss.item(), num_correct, labels.size(0)

def main_train_loop(model, trainloader, valloader, optimizer, criterion, scheduler, num_epochs, device, save_path):
    """Train ``model`` for ``num_epochs`` epochs and checkpoint the best validation accuracy.

    Each epoch runs ``train_batch`` over ``trainloader``, evaluates with
    ``evaluate_batch`` over ``valloader``, prints the epoch metrics, steps the
    scheduler once, and saves ``model.state_dict()`` to
    ``<save_path>/best_model.pth`` whenever validation accuracy improves.

    Args:
        model: Network to train.
        trainloader: Iterable of ``(images, labels)`` training batches.
        valloader: Iterable of ``(images, labels)`` validation batches.
        optimizer: Optimizer passed through to ``train_batch``.
        criterion: Loss passed through to ``train_batch`` / ``evaluate_batch``.
        scheduler: Learning-rate scheduler, stepped once per epoch.
        num_epochs: Number of epochs to run.
        device: Device batches are moved to.
        save_path: Directory for the checkpoint; created if it does not exist.

    Returns:
        None. Progress is reported via ``print``.
    """
    best_val_accuracy = 0.0
    model_save_path = os.path.join(save_path, 'best_model.pth')

    model.train()
    for epoch in range(num_epochs):
        running_loss = 0.0
        start_time = time.time()

        # Training phase
        for images, labels in trainloader:
            loss = train_batch(images, labels, model, optimizer, criterion, device)
            running_loss += loss

        # Mean of the per-batch losses
        epoch_loss = running_loss / len(trainloader)

        # Validation phase
        model.eval()
        val_loss = 0.0
        correct_predictions = 0
        total_predictions = 0
        for images, labels in valloader:
            loss, correct, total = evaluate_batch(images, labels, model, criterion, device)
            val_loss += loss
            correct_predictions += correct
            total_predictions += total

        epoch_val_loss = val_loss / len(valloader)
        accuracy = correct_predictions / total_predictions

        # Epoch time covers both the training and validation phases
        end_time = time.time()
        epoch_time = end_time - start_time

        print(f'Epoch [{epoch+1}/{num_epochs}], '
              f'Train Loss: {epoch_loss:.4f}, '
              f'Val Loss: {epoch_val_loss:.4f}, '
              f'Val Accuracy: {accuracy:.4f}, '
              f'Epoch Time: {epoch_time:.2f}s')

        # Step the scheduler
        scheduler.step()

        # Save the best model
        if accuracy > best_val_accuracy:
            best_val_accuracy = accuracy
            # torch.save does not create missing directories, so make sure the target exists
            os.makedirs(save_path, exist_ok=True)
            torch.save(model.state_dict(), model_save_path)
            print(f"Saved best model to {model_save_path} with validation accuracy: {best_val_accuracy:.4f}")


        model.train() # Set model back to training mode


    print('Finished Training')


In [8]:
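# Training call is disabled. Uncomment it to (re)train; it writes best_model.pth under save_path, which the loading cell below reads.
# main_train_loop(model, trainloader, valloader, optimizer, criterion, scheduler, num_epochs, device, save_path)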

## Test model after training

In [9]:
import torch
import os

# Define the path to the saved model
save_path = "/dt/yisroel/Users/Data_Memorization/song_memorization/CVE/"
model_save_path = os.path.join(save_path, 'best_model.pth')

# Resolve the target device before loading so the checkpoint tensors can be remapped onto it
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the saved state dictionary
if os.path.exists(model_save_path):
    # map_location lets a checkpoint saved from a GPU run load on a CPU-only machine
    model.load_state_dict(torch.load(model_save_path, map_location=device))
    print(f"Model loaded successfully from {model_save_path}")
else:
    # The cells below still run in this case, but on whatever weights `model` currently holds
    print(f"No model found at {model_save_path}")

# Move the model to the device
model.to(device)

print(f"Model is on device: {next(model.parameters()).device}")

Model loaded successfully from /dt/yisroel/Users/Data_Memorization/song_memorization/CVE/best_model.pth
Model is on device: cuda:0


In [13]:
import torch

def test_model(model, testloader, device):
    """Evaluate ``model`` on ``testloader`` and print mean loss and accuracy.

    The loss is computed with the notebook-level ``criterion`` via
    ``evaluate_batch``. The printed loss is the mean of the per-batch losses;
    the accuracy is correct predictions divided by samples seen.

    Args:
        model: Network to evaluate; switched to eval mode.
        testloader: Iterable of ``(images, labels)`` batches.
        device: Device batches are moved to.

    Returns:
        None. Results are reported via ``print``.
    """
    model.eval()

    loss_total = 0.0
    num_correct = 0
    num_seen = 0

    # Accumulate per-batch metrics with gradient tracking disabled
    with torch.no_grad():
        for batch_images, batch_labels in testloader:
            batch_loss, batch_correct, batch_count = evaluate_batch(
                batch_images, batch_labels, model, criterion, device
            )
            loss_total += batch_loss
            num_correct += batch_correct
            num_seen += batch_count

    mean_loss = loss_total / len(testloader)
    accuracy = num_correct / num_seen

    print(f'Test Loss: {mean_loss:.4f}, Test Accuracy: {accuracy:.4f}')

## CVE encoding test

In [14]:
from song_code_reproduction.src.ssim_eval import ssim
from song_code_reproduction.src.pruning import prune_model_global_l1

In [15]:
def calc_ssim(memorization_size, memorization_set, recon):
    """Return the mean SSIM-based score between memorized images and reconstructions.

    For each index ``i`` below ``memorization_size`` the source image
    ``memorization_set[i][0]`` is compared against ``recon[i]`` and against its
    inversion ``1 - recon[i]``; the larger of the two scores is kept. Each raw
    ``ssim`` value ``s`` is mapped to ``0.5 * (s + 1)`` before the comparison
    (presumably rescaling [-1, 1] to [0, 1] -- confirm against ``ssim``).

    Args:
        memorization_size: Number of leading items to score.
        memorization_set: Indexable dataset whose items have the image at position 0.
        recon: Indexable collection of reconstructed images, aligned with ``memorization_set``.

    Returns:
        float: Mean of the per-image scores.
    """
    per_image_scores = []
    with torch.no_grad():
        for idx in range(memorization_size):
            # Add a leading batch dimension to both images before calling ssim
            reconstructed = recon[idx].unsqueeze(0)
            original = memorization_set[idx][0].unsqueeze(0)

            direct_score = 0.5 * (ssim(original, reconstructed) + 1)
            inverted_score = 0.5 * (ssim(original, 1 - reconstructed) + 1)

            # Keep whichever polarity of the reconstruction matches better
            per_image_scores.append(max(direct_score, inverted_score))

    return torch.tensor(per_image_scores).mean().item()

In [16]:
# Shape of one reconstructed image and the total number of values to decode.
# K is derived from memorization_size so it matches the number of images
# that calc_ssim scores, instead of a hard-coded 100.
item_shape = (1, 28, 28)
num_encoded_values = memorization_size * item_shape[0] * item_shape[1] * item_shape[2]

# Baseline: reconstruct from the loaded model and score it
recon = reconstruct_from_params(model=model,
                                K=num_encoded_values,
                                item_shape=item_shape,
                                value_range=(0.,1.))
ssim_before = calc_ssim(memorization_size=memorization_size, memorization_set=memorization_set, recon=recon)
print(f"SSIM before pruning: {ssim_before}")
test_model(model, testloader, device)

# Prune with amount 0.2, then repeat the same reconstruction and scoring.
# NOTE(review): if prune_model_global_l1 modifies `model` in place, `model` is also
# pruned from this point on -- confirm against its implementation.
pruned_model = prune_model_global_l1(model, 0.2)

recon = reconstruct_from_params(model=pruned_model,
                                K=num_encoded_values,
                                item_shape=item_shape,
                                value_range=(0.,1.))
ssim_after = calc_ssim(memorization_size=memorization_size, memorization_set=memorization_set, recon=recon)
print(f"SSIM after pruning: {ssim_after}")
test_model(pruned_model, testloader, device)

SSIM before pruning: 0.9853981137275696
Test Loss: 0.1121, Test Accuracy: 0.9800
SSIM after pruning: 0.5067136287689209
Test Loss: 0.1105, Test Accuracy: 0.9802
