In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import numpy as np
import random
import os

In [None]:
# Per-channel (R, G, B) statistics fed to transforms.Normalize in the data cell;
# the values match the commonly cited CIFAR-100 training-set statistics.
mean, std = [0.5071, 0.4867, 0.4408], [0.2675, 0.2565, 0.2761]

In [None]:
from torch.utils.data import random_split, DataLoader
import torchvision
import torchvision.transforms as transforms


# Define the transformations for training and validation.
# ToTensor() must come before Normalize(): Normalize operates on tensors, while
# RandomCrop / RandomHorizontalFlip run on the PIL image first.
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
])

transform_val = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
])

# Load the CIFAR-100 training split twice, one instance per transform pipeline.
# Subsets of a *single* dataset object share its `.transform` attribute, so setting
# the validation transform on one would silently strip augmentation from the other.
full_trainset = torchvision.datasets.CIFAR100(root='./data', train=True, download=True, transform=transform_train)
full_valset = torchvision.datasets.CIFAR100(root='./data', train=True, download=True, transform=transform_val)

# Random 90/10 train/validation split over the same image indices.
# The fixed seed keeps the split identical across kernel restarts.
SPLIT_SEED = 0
train_size = int(0.9 * len(full_trainset))
val_size = len(full_trainset) - train_size
split_generator = torch.Generator().manual_seed(SPLIT_SEED)
shuffled_indices = torch.randperm(len(full_trainset), generator=split_generator).tolist()

trainset = torch.utils.data.Subset(full_trainset, shuffled_indices[:train_size])
valset = torch.utils.data.Subset(full_valset, shuffled_indices[train_size:])
assert len(valset) == val_size

# Create DataLoaders for training and validation sets
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True, num_workers=0)
valloader = torch.utils.data.DataLoader(valset, batch_size=100, shuffle=False, num_workers=0)

# Test set (held-out CIFAR-100 test split, evaluation transform only)
testset = torchvision.datasets.CIFAR100(root='./data', train=False, download=True, transform=transform_val)
testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=0)


In [None]:
from torchvision.models import resnet18

# --- run length ---
# `epochs` is the full schedule length; the training loop runs the final
# `epochs_to_run` of them (range(epochs - epochs_to_run, ...)).
epochs, epochs_to_run = 200, 200

# --- optimizer / sampler hyper-parameters ---
# NOTE(review): lr_0 is not passed to adjust_learning_rate_standard in this
# notebook; that function uses its own default a=0.1.
lr_0 = 0.1  # initial learning rate
weight_decay = 5e-4
alpha = 0.9
temperature = 1.0 / 50000
# NOTE(review): trainset holds only the 90% train split (45k of 50k images);
# confirm 50000 is the intended N in the SGLD noise scale.
datasize = 50000
# NOTE(review): M and T are not referenced by any other cell in this notebook.
M = 4  # number of cycles, change as needed
num_batch = len(trainloader)
T = epochs * num_batch
criterion = nn.CrossEntropyLoss()

# --- checkpointing ---
mt = 0  # checkpoint number
save_dir = "./checkpoints"
os.makedirs(save_dir, exist_ok=True)

# --- device and model ---
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)

# Step 1: Rebuild the model architecture exactly as before
net = resnet18(num_classes=100).to(device)


In [None]:
def update_params_sghmc(alpha, temperature, datazise, lr, epoch, tot_epochs, wd=None, model=None):
    """Apply one SGHMC (stochastic-gradient Hamiltonian Monte Carlo) step in place.

    A momentum buffer ``p.buf`` is kept on each parameter. For the first
    ``tot_epochs - 50`` epochs (burn-in) the update is noise-free momentum SGD.
    In the final 50 epochs, Gaussian noise with standard deviation
    ``sqrt(2 * lr * temperature / datazise)`` is added to the momentum.

    Args:
        alpha: friction; the previous momentum is scaled by ``1 - alpha``.
        temperature: noise temperature.
        datazise: dataset size used in the noise scale. The misspelled name is
            kept for backward compatibility with positional/keyword callers.
        lr: step size.
        epoch: current epoch index; selects burn-in vs. sampling phase.
        tot_epochs: total number of epochs in the run.
        wd: L2 weight-decay coefficient; defaults to the global ``weight_decay``.
        model: module whose parameters are updated; defaults to the global ``net``.
    """
    model = net if model is None else model
    wd = weight_decay if wd is None else wd
    add_noise = epoch >= tot_epochs - 50

    with torch.no_grad():
        for p in model.parameters():
            # Frozen or unused parameters have no gradient this step.
            if p.grad is None:
                continue
            if not hasattr(p, 'buf'):
                # zeros_like keeps the buffer on p's device/dtype (works on CPU and GPU).
                p.buf = torch.zeros_like(p)

            # Out-of-place: leave p.grad untouched for any later inspection.
            d_p = p.grad + wd * p
            buf_new = (1 - alpha) * p.buf - lr * d_p

            if add_noise:
                noise = torch.randn_like(p)
                buf_new += (2.0 * lr * temperature / datazise) ** 0.5 * noise

            p.add_(buf_new)
            p.buf = buf_new

def update_params_sgld(lr, epoch, tot_epochs, weight_decay, model=None, temp=None, n_data=None):
    """Apply one SGLD (stochastic-gradient Langevin dynamics) step in place.

    For the first ``tot_epochs - 50`` epochs (burn-in) this is plain SGD with L2
    weight decay. In the final 50 epochs, Gaussian noise with standard deviation
    ``sqrt(2 * lr * temp / n_data)`` is added to each parameter update.

    Args:
        lr: step size.
        epoch: current epoch index; selects burn-in vs. sampling phase.
        tot_epochs: total number of epochs in the run.
        weight_decay: L2 coefficient added to the gradient.
        model: module whose parameters are updated; defaults to the global ``net``.
        temp: noise temperature; defaults to the global ``temperature``.
        n_data: dataset size in the noise scale; defaults to the global ``datasize``.
    """
    model = net if model is None else model
    add_noise = epoch >= tot_epochs - 50

    noise_std = 0.0
    if add_noise:
        # Resolve the noise globals lazily: burn-in steps never need them.
        temp = temperature if temp is None else temp
        n_data = datasize if n_data is None else n_data
        noise_std = (2 * lr * temp / n_data) ** 0.5

    with torch.no_grad():
        for p in model.parameters():
            if p.grad is None:
                continue

            # Add L2 regularization (weight decay) to the gradient
            grad = p.grad + weight_decay * p
            step = lr * grad
            if add_noise:
                # randn_like already matches p's device/dtype, so no .cuda() is
                # needed (calling .cuda('cpu') would raise on CPU-only runs).
                step = step - noise_std * torch.randn_like(p)
            p.sub_(step)


def adjust_learning_rate_standard(t, a=0.1, b=1.0, gamma=0.55, limit_iter=500):
    """Polynomially decaying step size.

    Holds the rate at ``a`` until iteration ``limit_iter``. From then on it
    returns ``a * (b + t - limit_iter) ** (-gamma)``.
    """
    if t >= limit_iter:
        return a * (b + t - limit_iter) ** (-gamma)
    return a

In [None]:
def train(net, trainLoader, alpha, temperature, datasize, num_batch, epoch, tot_epochs, weight_decay):
    """Run one SGLD training epoch over ``trainLoader``.

    For each batch, the learning rate is taken from ``adjust_learning_rate_standard``
    at the running global iteration. Parameters are then updated with
    ``update_params_sgld``.

    NOTE(review): ``alpha``, ``temperature`` and ``datasize`` are accepted but not
    used here. ``update_params_sgld`` is called without a model, so it updates the
    global ``net`` rather than this function's ``net`` argument.

    Returns:
        (mean loss per batch, training accuracy in percent, last learning rate used)
    """
    print('\nEpoch: %d' % epoch)
    net.train()

    # Iteration counter continues across epochs so the LR schedule is global.
    global_iter = epoch * num_batch
    print("global iter", global_iter)

    running_loss, n_correct, n_seen = 0, 0, 0
    for batch_idx, (inputs, targets) in enumerate(trainLoader):
        if device == 'cuda':
            inputs = inputs.cuda(device)
            targets = targets.cuda(device)

        net.zero_grad()
        lr = adjust_learning_rate_standard(global_iter)
        outputs = net(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        update_params_sgld(lr, epoch, tot_epochs, weight_decay)

        running_loss += loss.data.item()
        predicted = torch.max(outputs.data, 1).indices
        n_seen += targets.size(0)
        n_correct += predicted.eq(targets.data).cpu().sum()

        if batch_idx % 100 == 0:
            progress = 'Loss: %.3f | Acc: %.3f%% (%d/%d)' % (
                running_loss / (batch_idx + 1), 100. * n_correct.item() / n_seen, n_correct, n_seen)
            print(progress, f"learnig rate {lr,global_iter}")

        global_iter += 1

    return running_loss / len(trainLoader), 100. * n_correct.item() / n_seen, lr
    

def test(net, valLoader, epoch):
    """Evaluate ``net`` on ``valLoader`` without gradient tracking.

    Prints a running summary every 100 batches and a final summary line.
    ``epoch`` is accepted for call-site symmetry with ``train`` but is not used.

    Returns:
        (mean loss per batch, accuracy in percent)
    """
    net.eval()
    loss_sum, n_correct, n_seen = 0, 0, 0

    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(valLoader):
            if device == 'cuda':
                inputs, targets = inputs.cuda(device), targets.cuda(device)

            outputs = net(inputs)
            loss_sum += criterion(outputs, targets).data.item()

            predicted = torch.max(outputs.data, 1).indices
            n_seen += targets.size(0)
            n_correct += predicted.eq(targets.data).cpu().sum()

            if batch_idx % 100 == 0:
                print('Test Loss: %.3f | Test Acc: %.3f%% (%d/%d)'
                      % (loss_sum / (batch_idx + 1), 100. * n_correct.item() / n_seen, n_correct, n_seen))

    n_batches = len(valLoader)
    mean_loss = loss_sum / n_batches
    test_accuracy = 100. * n_correct.item() / n_seen

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n'.format(
        mean_loss, n_correct, n_seen, test_accuracy))

    return mean_loss, test_accuracy

In [None]:
# Initialize Weights & Biases logging; fill in `project` and `run_name` yourself.
import wandb

# Never hard-code the API key in the notebook: read it from the environment.
# If WANDB_API_KEY is unset, key=None lets wandb fall back to stored credentials.
wandb.login(key=os.environ.get("WANDB_API_KEY"))

run_name = ""

wandb.init(project='', name=run_name, config={
    'learning_rate': lr_0,                 # initial step size; alpha is the SGHMC friction
    'alpha': alpha,
    'batch_size': trainloader.batch_size,  # actual minibatch size, not the dataset size
    'datasize': datasize,
    'epochs': epochs,
    'weight_decay': weight_decay,
    'temperature': temperature,
})


In [None]:
saved_epochs = set()
total_models_to_save = 12
last_epochs_to_save = 3
random_models_to_save = total_models_to_save - last_epochs_to_save
noise_epochs = 50  # update_params_sgld injects noise only in the final 50 epochs
mt = 0  # model counter used for naming the sampled (noisy-phase) snapshots

print(num_batch)

# Snapshot schedule: always the last `last_epochs_to_save` epochs, plus a random
# sample of the other noisy epochs. Sampling excludes the final epochs so the two
# sets never overlap and exactly `total_models_to_save` snapshots are written.
sampling_start = max(0, epochs - noise_epochs)
final_saved_epochs = list(range(epochs - last_epochs_to_save, epochs))
random_saved_epochs = sorted(
    random.sample(range(sampling_start, epochs - last_epochs_to_save), random_models_to_save))
all_saved_epochs = set(random_saved_epochs + final_saved_epochs)

print(all_saved_epochs)

# Train the final `epochs_to_run` epochs of the `epochs`-long schedule.
for epoch in range(epochs - epochs_to_run, epochs):
    # Training loop
    train_loss, train_accuracy, learn_rate = train(
        net, trainloader, alpha, temperature, datasize, num_batch, epoch, epochs, weight_decay)

    # Validation loop
    test_loss, test_accuracy = test(net, valloader, epoch)

    wandb.log({
        'train_loss': train_loss,  # train() already returns the per-batch mean
        'train_accuracy': train_accuracy,
        'test_loss': test_loss,
        "test_accuracy": test_accuracy,
        "epoch": epoch,
        "lr": learn_rate
    })

    # Save random snapshots from the noisy phase + the last few epochs.
    # The model is already on `device`; no .cuda() call is needed (it fails on CPU).
    if epoch in all_saved_epochs:
        print(f"Saving model snapshot at epoch {epoch}")
        model_name = f'{save_dir}/cifar100_sgld_scaled_lr01_{mt}.pt'
        print(model_name)
        mt += 1
        torch.save(net.state_dict(), model_name)
        saved_epochs.add(epoch)

    # Save the burned-in model (last noise-free epoch). It has its own file name,
    # so it must not consume an index from the sampled-snapshot counter `mt`.
    if epoch == sampling_start - 1:
        print(f"Saving model snapshot at epoch {epoch}")
        model_name = f'{save_dir}/cifar100_sgld_finalNoNoise.pt'
        print(model_name)
        torch.save(net.state_dict(), model_name)
        saved_epochs.add(epoch)

print(saved_epochs)
