In [1]:
from copy import copy
import random
import numpy as np
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from torch.utils.data import random_split, Subset
from torch.optim import lr_scheduler

# Set seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)
random.seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(42)

In [2]:
from song_code_reproduction.src.mnist_fcn import FCN_parameters, FCModel, LEARNING_RATE, WEIGHT_DECAY


In [3]:
save_path = "/dt/yisroel/Users/Data_Memorization/song_memorization/LSB/"

memorization_size = 100

# Training Params
BATCH_SIZE = 128
num_epochs = 60 # Example number of epochs
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


In [4]:
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import random_split, Subset
import torch
from torch.utils.data import DataLoader

# Define transformations for training with augmentation
train_transform = transforms.Compose([
    transforms.ToTensor(),
])

# Define transformations for validation, testing, and memorization set (no augmentation)
eval_transform = transforms.Compose([
    transforms.ToTensor(),
])

mem_transform = transforms.Compose([transforms.ToTensor(),
                                    transforms.Lambda(lambda t: (t*255).byte().permute(1,2,0).numpy())])


# Download the full CIFAR100 training dataset
full_trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True)
testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=eval_transform)


# Split the full training set into training and validation indices
train_size = int(0.8 * len(full_trainset))
val_size = len(full_trainset) - train_size
train_indices, val_indices = random_split(range(len(full_trainset)), [train_size, val_size])

# Create training and validation datasets using indices and applying appropriate transforms
train_dataset = Subset(full_trainset, train_indices.indices)
val_dataset = Subset(full_trainset, val_indices.indices)

# Apply transforms to the datasets
train_dataset.dataset.transform = train_transform
val_dataset.dataset.transform = eval_transform

if memorization_size > len(full_trainset):
    memorization_size = len(full_trainset)
memorization_set = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=mem_transform)
memorization_indices = torch.randperm(len(full_trainset))[:memorization_size].tolist()
memorization_set = Subset(memorization_set, memorization_indices)


# Create DataLoaders
trainloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
valloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
testloader = DataLoader(testset, batch_size=BATCH_SIZE, shuffle=False)



print("DataLoaders created successfully!")
print(f"Number of images in training set: {len(train_dataset)}")
print(f"Number of images in validation set: {len(val_dataset)}")
print(f"Number of images in test set: {len(testset)}")
print(f"Number of images in memorization set: {len(memorization_set)}")

DataLoaders created successfully!
Number of images in training set: 48000
Number of images in validation set: 12000
Number of images in test set: 10000
Number of images in memorization set: 100


In [5]:
model = FCModel(**FCN_parameters).to(device)

In [6]:
criterion = nn.CrossEntropyLoss()

optimizer = optim.AdamW(model.parameters(), lr= LEARNING_RATE, weight_decay=WEIGHT_DECAY)

scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

print("Training hyperparameters and objects defined.")
print(f"Learning Rate: {LEARNING_RATE}")
print(f"Number of Epochs: {num_epochs}")
print(f"Loss Function: {type(criterion).__name__}")
print(f"Optimizer Type: {type(optimizer).__name__}")
print(f"Scheduler Type: {type(scheduler).__name__}")
print(f"Device: {device}")

Training hyperparameters and objects defined.
Learning Rate: 0.0001
Number of Epochs: 60
Loss Function: CrossEntropyLoss
Optimizer Type: AdamW
Scheduler Type: CosineAnnealingLR
Device: cuda


In [7]:
import torch
import time
import os

def train_batch(images, labels, model, optimizer, criterion, device):
    images, labels = images.to(device), labels.to(device)

    # Forward pass
    outputs = model(images)
    loss = criterion(outputs, labels)

    # Backward and optimize
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    return loss.item()

def evaluate_batch(images, labels, model, criterion, device):
    images, labels = images.to(device), labels.to(device)
    with torch.no_grad():
      # Forward pass
      outputs = model(images)
      loss = criterion(outputs, labels)

      # Calculate accuracy
      _, predicted = torch.max(outputs.data, 1)
      correct = (predicted == labels).sum().item()

    return loss.item(), correct, labels.size(0)

def main_train_loop(model, trainloader, valloader, optimizer, criterion, scheduler, num_epochs, device, save_path):
    best_val_accuracy = 0.0

    model.train()
    for epoch in range(num_epochs):
        running_loss = 0.0
        start_time = time.time()

        # Training phase
        for i, (images, labels) in enumerate(trainloader):
            loss = train_batch(images, labels, model, optimizer, criterion, device)
            running_loss += loss

        epoch_loss = running_loss / len(trainloader)

        # Validation phase
        model.eval()
        val_loss = 0.0
        correct_predictions = 0
        total_predictions = 0
        for images, labels in valloader:
            loss, correct, total = evaluate_batch(images, labels, model, criterion, device)
            val_loss += loss
            correct_predictions += correct
            total_predictions += total

        epoch_val_loss = val_loss / len(valloader)
        accuracy = correct_predictions / total_predictions

        end_time = time.time()
        epoch_time = end_time - start_time

        print(f'Epoch [{epoch+1}/{num_epochs}], '
              f'Train Loss: {epoch_loss:.4f}, '
              f'Val Loss: {epoch_val_loss:.4f}, '
              f'Val Accuracy: {accuracy:.4f}, '
              f'Epoch Time: {epoch_time:.2f}s')

        # Step the scheduler
        scheduler.step()

        # Save the best model
        if accuracy > best_val_accuracy:
            best_val_accuracy = accuracy
            model_save_path = os.path.join(save_path, 'best_model.pth')
            torch.save(model.state_dict(), model_save_path)
            print(f"Saved best model to {model_save_path} with validation accuracy: {best_val_accuracy:.4f}")


        model.train() # Set model back to training mode


    print('Finished Training')


In [8]:
main_train_loop(model, trainloader, valloader, optimizer, criterion, scheduler, num_epochs, device, save_path)

Epoch [1/60], Train Loss: 0.5346, Val Loss: 0.2492, Val Accuracy: 0.9284, Epoch Time: 3.20s
Saved best model to /dt/yisroel/Users/Data_Memorization/song_memorization/LSB/best_model.pth with validation accuracy: 0.9284
Epoch [2/60], Train Loss: 0.1848, Val Loss: 0.1657, Val Accuracy: 0.9540, Epoch Time: 2.69s
Saved best model to /dt/yisroel/Users/Data_Memorization/song_memorization/LSB/best_model.pth with validation accuracy: 0.9540
Epoch [3/60], Train Loss: 0.1247, Val Loss: 0.1337, Val Accuracy: 0.9600, Epoch Time: 2.69s
Saved best model to /dt/yisroel/Users/Data_Memorization/song_memorization/LSB/best_model.pth with validation accuracy: 0.9600
Epoch [4/60], Train Loss: 0.0926, Val Loss: 0.1127, Val Accuracy: 0.9661, Epoch Time: 2.71s
Saved best model to /dt/yisroel/Users/Data_Memorization/song_memorization/LSB/best_model.pth with validation accuracy: 0.9661
Epoch [5/60], Train Loss: 0.0682, Val Loss: 0.0992, Val Accuracy: 0.9706, Epoch Time: 2.69s
Saved best model to /dt/yisroel/User

## Test model after training

In [9]:
import torch
import os

# Define the path to the saved model
save_path = "/dt/yisroel/Users/Data_Memorization/song_memorization/LSB/mnist"
model_save_path = os.path.join(save_path, 'best_model.pth')

# Instantiate the model (make sure the model architecture is defined in a previous cell)
# Assuming 'model' is already defined and is an instance of your ViT class
# model = VisionTransformer(...) # If not already defined, define it here with the correct parameters

# Load the saved state dictionary
if os.path.exists(model_save_path):
    model.load_state_dict(torch.load(model_save_path))
    print(f"Model loaded successfully from {model_save_path}")
else:
    print(f"No model found at {model_save_path}")

# Move the model to the device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

print(f"Model is on device: {next(model.parameters()).device}")

No model found at /dt/yisroel/Users/Data_Memorization/song_memorization/LSB/mnist/best_model.pth
Model is on device: cuda:0


In [10]:
# import torch

def test_model(model, testloader, device):
    # Set the model to evaluation mode
    model.eval()

    test_loss = 0.0
    correct_predictions = 0
    total_predictions = 0

    # Iterate over the test data
    with torch.no_grad():
        for images, labels in testloader:
            loss, correct, total = evaluate_batch(images, labels, model, criterion, device)
            test_loss += loss
            correct_predictions += correct
            total_predictions += total

    # Calculate average test loss and accuracy
    average_test_loss = test_loss / len(testloader)
    test_accuracy = correct_predictions / total_predictions

    print(f'Test Loss: {average_test_loss:.4f}, Test Accuracy: {test_accuracy:.4f}')

## LSB encoding test

In [11]:
from song_code_reproduction.src.robust_lsb_encoding import (build_tail_payload_from_dataset,
                                                            model_capacity_last_byte,
                                                            embed_bytes_into_model_last_byte,
                                                            extract_all_bytes_from_model_last_byte,
                                                            forgiving_tail_parse_from_end)

from song_code_reproduction.src.pruning import prune_model_global_l1

from song_code_reproduction.src.ssim_eval import ssim

In [12]:
def calc_ssim(memorization_size, recovered_imgs, memorization_set):
  ssim_scores = []
  for i in range(memorization_size):
    with torch.no_grad():
      ext_img = torch.tensor(recovered_imgs[i]).permute(2, 0, 1).unsqueeze(0).float()
      src_img = torch.tensor(memorization_set[i][0]).permute(2, 0, 1).unsqueeze(0).float()
      ssim_score = 0.5*(1+ssim(ext_img, src_img))
      ssim_scores.append(ssim_score)
  ssim_scores = torch.tensor(ssim_scores)
  return ssim_scores.mean()

In [13]:
# 1) Build payload
payload, metas = build_tail_payload_from_dataset(memorization_set, max_images=memorization_size)

# 2) Embed AT THE END
cap = model_capacity_last_byte(model)
if len(payload) > cap: raise RuntimeError(f"Payload {len(payload)} > cap {cap}")
_ = embed_bytes_into_model_last_byte(model, payload, from_end=True)

# 3) Immediate round-trip sanity (no pruning)
rb_all = extract_all_bytes_from_model_last_byte(model)
assert rb_all[-len(payload):] == payload, "Tail embed mismatch (should never happen without pruning)"

# 4) Decode from END
imgs = forgiving_tail_parse_from_end(rb_all, metas)
print("Recovered", len(imgs), "images (no pruning)")

print("SSIM before pruning: ", calc_ssim(memorization_size=memorization_size, recovered_imgs=imgs,
                memorization_set=memorization_set).item())
test_model(model, testloader, device)

# ---- prune the model here ----
pruned_model = prune_model_global_l1(model, 0.2)

# 5) After pruning, decode from END again (works even if bytes flipped/zeroed)
rb_all_after = extract_all_bytes_from_model_last_byte(pruned_model)
imgs_after = forgiving_tail_parse_from_end(rb_all_after, metas)
print("Recovered after pruning:", len(imgs_after))
print("SSIM After pruning: ", calc_ssim(memorization_size=memorization_size, recovered_imgs=imgs_after,
                memorization_set=memorization_set).item())
test_model(pruned_model, testloader, device)


Recovered 100 images (no pruning)
SSIM before pruning:  1.0
Test Loss: 0.1642, Test Accuracy: 0.9789
Recovered after pruning: 100
SSIM After pruning:  0.7124541401863098
Test Loss: 0.1627, Test Accuracy: 0.9787
