In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import numpy as np
import random
import os

In [None]:
# Normalisation constants: three per-channel values each, passed to
# transforms.Normalize(mean, std) when the image pipelines are built.
# NOTE(review): these look like the commonly quoted CIFAR-100 statistics — confirm.
mean, std = [0.5071, 0.4867, 0.4408], [0.2675, 0.2565, 0.2761]

In [None]:
from torch.utils.data import random_split, DataLoader
import torchvision
import torchvision.transforms as transforms


# Per-sample transformations.
# ToTensor() has to run before Normalize(): CIFAR100 yields PIL images and
# Normalize only operates on tensors.
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
])

transform_val = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
])

# Two independent views of the CIFAR-100 training split, one per transform.
# A Subset does not copy its parent dataset, so assigning `.dataset.transform`
# on two Subsets of the SAME dataset makes the last assignment win for both
# (training would silently run with the validation transform, i.e. without
# augmentation). Separate dataset objects keep the two pipelines apart.
full_trainset = torchvision.datasets.CIFAR100(
    root='./data', train=True, download=True, transform=transform_train)
full_valset = torchvision.datasets.CIFAR100(
    root='./data', train=True, download=False, transform=transform_val)

# 90% / 10% split: the first 90% of the indices train, the remainder validates.
# The index ranges are disjoint, so no sample appears in both subsets.
train_size = int(0.9 * len(full_trainset))
val_size = len(full_trainset) - train_size

trainset = torch.utils.data.Subset(full_trainset, range(train_size))
valset = torch.utils.data.Subset(full_valset, range(train_size, len(full_valset)))

# DataLoaders for the training and validation subsets
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True, num_workers=0)
valloader = torch.utils.data.DataLoader(valset, batch_size=100, shuffle=False, num_workers=0)

# Held-out test split, evaluated with the validation transform
testset = torchvision.datasets.CIFAR100(root='./data', train=False, download=True, transform=transform_val)
testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=0)


In [None]:
from torchvision.models import resnet18

# --- Schedule ---------------------------------------------------------------
epochs = 200                  # length of the full schedule, in epochs
epochs_to_run = 200           # number of epochs at the end of the schedule to execute
M = 4                         # number of cycles
num_batch = len(trainloader)  # mini-batches per epoch
T = num_batch * epochs        # total number of iterations in the schedule

# --- Optimisation / sampling hyper-parameters -------------------------------
lr_0 = 0.1                    # initial learning rate
weight_decay = 5e-4
alpha = 0.9
datasize = 50000
# NOTE(review): trainloader iterates over 90% of the training split, while
# datasize is fixed at 50000 — confirm this is the intended noise scaling.
temperature = 1.0 / 50000     # only used when noise is injected
criterion = nn.CrossEntropyLoss()  # loss used for training and evaluation

# --- Checkpointing ----------------------------------------------------------
mt = 0                        # checkpoint number
save_dir = "./checkpoints"
os.makedirs(save_dir, exist_ok=True)

# --- Device and model -------------------------------------------------------
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print (device)

# ResNet-18 with a 100-way output layer, placed on the selected device
net = resnet18(num_classes=100).to(device)




In [None]:
## DEFINE CYCLICAL FUNCTIONS

def update_params_csgld(lr, epoch, weight_decay):
    """Apply one cyclical-SGLD step, in place, to every parameter of the global ``net``.

    For each parameter that has a gradient, ``weight_decay * p`` is added to the
    gradient (in place) and the parameter moves by ``-lr * grad``. When
    ``(epoch % 50) + 1 > 45`` — the last 5 epochs of every 50-epoch block —
    Gaussian noise with standard deviation
    ``sqrt(2 * lr * temperature / datasize)`` is added as well.

    Reads the module-level names ``net``, ``temperature`` and ``datasize``.

    Args:
        lr: step size for this iteration.
        epoch: zero-based epoch index; decides whether noise is injected.
        weight_decay: L2 coefficient folded into the gradient.
    """
    step = float(lr)
    inject_noise = (epoch % 50) + 1 > 45
    noise_std = (2.0 * step * temperature / datasize) ** 0.5

    for p in net.parameters():
        if p.grad is None:
            continue
        d_p = p.grad.data
        # Keyword `alpha=` replaces the deprecated add_(scalar, tensor) overload.
        d_p.add_(p.data, alpha=weight_decay)
        p.data.add_(d_p, alpha=-step)

        if inject_noise:
            # randn_like samples on the parameter's own device and dtype, so the
            # update also works on CPU (`.cuda('cpu')` raises) and avoids a
            # host-to-device copy on every step.
            p.data.add_(torch.randn_like(p.data), alpha=noise_std)



def update_params_chmcmc(alpha,temperature,datasize,lr,epoch,weight_decay):
    """Apply one cyclical SG-HMC (momentum) step, in place, to the global ``net``.

    Each parameter carries a momentum buffer stored as the attribute ``buf`` on
    the parameter object (created as zeros on first use). Per step:

        grad  <- grad + weight_decay * p              (in place)
        buf   <- (1 - alpha) * buf - lr * grad
        buf   <- buf + sqrt(2 * lr * alpha * temperature / datasize) * eps
                 (only when ``(epoch % 50) + 1 > 45``, eps ~ N(0, 1))
        p     <- p + buf

    Parameters without a gradient are skipped, matching ``update_params_csgld``.

    Args:
        alpha: friction term; ``1 - alpha`` is the fraction of momentum kept.
        temperature: scales the injected noise.
        datasize: divisor used in the noise scale.
        lr: step size for this iteration.
        epoch: zero-based epoch index; decides whether noise is injected.
        weight_decay: L2 coefficient folded into the gradient.
    """
    inject_noise = (epoch % 50) + 1 > 45

    for p in net.parameters():
        if p.grad is None:
            # Nothing to update for parameters that took no part in the backward pass.
            continue
        if not hasattr(p, 'buf'):
            # zeros_like keeps the buffer on the parameter's device, so this
            # also works on CPU (`.cuda('cpu')` raises).
            p.buf = torch.zeros_like(p.data)
        d_p = p.grad.data
        # Keyword `alpha=` replaces the deprecated add_(scalar, tensor) overload.
        d_p.add_(p.data, alpha=weight_decay)
        buf_new = (1 - alpha) * p.buf - lr * d_p

        if inject_noise:
            eps = torch.randn_like(p.data)
            buf_new += (2.0 * lr * alpha * temperature / datasize) ** .5 * eps
        p.data.add_(buf_new)
        p.buf = buf_new


def adjust_learning_rate_cyclical(epoch, batch_idx, num_batch, T, lr_0=0.1, M=4):
    """Cosine learning rate that restarts at the beginning of every cycle.

    The schedule of ``T`` iterations is cut into ``M`` cycles of ``T // M``
    iterations. Within a cycle the rate follows half a cosine period, starting
    at ``lr_0`` and decreasing towards 0 as the cycle ends.

    Args:
        epoch: zero-based epoch index.
        batch_idx: zero-based mini-batch index inside the epoch.
        num_batch: mini-batches per epoch.
        T: total number of iterations in the schedule.
        lr_0: learning rate at the start of each cycle.
        M: number of cycles.

    Returns:
        The learning rate for the current iteration.
    """
    cycle_len = T // M
    iteration = epoch * num_batch + batch_idx

    # Position inside the current cycle, mapped onto [0, pi)
    phase = np.pi * (iteration % cycle_len)
    phase /= cycle_len

    return 0.5 * (np.cos(phase) + 1) * lr_0

In [None]:
def train(net,trainLoader,alpha,temperature, datasize, num_batch, epoch,tot_epochs, weight_decay,T, lr_0=0.1, M=4):
    """Run one training epoch with the cyclical SG-HMC update and print running metrics.

    Args:
        net: model to train; switched to training mode and used for the forward pass.
        trainLoader: iterable of ``(inputs, targets)`` mini-batches; must support ``len()``.
        alpha, temperature, datasize, weight_decay: forwarded to ``update_params_chmcmc``.
        num_batch: mini-batches per epoch (used by the learning-rate schedule).
        epoch: zero-based epoch index.
        tot_epochs: accepted for call compatibility; not used.
        T: total number of iterations in the schedule.
        lr_0: learning rate at the start of each cycle (default matches the schedule helper).
        M: number of cycles (default matches the schedule helper).

    Returns:
        ``(mean loss per batch, accuracy in percent, learning rate of the last batch)``.

    Raises:
        ValueError: if ``trainLoader`` yields no samples.

    Note:
        ``update_params_chmcmc`` updates the module-level ``net``, so the model
        passed in here is expected to be that same object.
    """
    print('\nEpoch: %d' % epoch)
    net.train()
    train_loss = 0
    correct = 0
    total = 0
    lr = None
    global_iter = epoch * num_batch  # only used for the progress print-out
    print("global iter",global_iter)

    for batch_idx, (inputs, targets) in enumerate(trainLoader):
        # .to() is a no-op when the tensors already live on `device`
        inputs, targets = inputs.to(device), targets.to(device)

        net.zero_grad()
        # Pass lr_0 / M through so the schedule follows the caller's configuration
        # instead of silently falling back to the helper's defaults.
        lr = adjust_learning_rate_cyclical(epoch, batch_idx, num_batch, T, lr_0=lr_0, M=M)
        outputs = net(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        update_params_chmcmc(alpha,temperature,datasize,lr,epoch,weight_decay)

        train_loss += loss.item()
        _, predicted = torch.max(outputs.detach(), 1)
        total += targets.size(0)
        # Accumulate a plain int; avoids copying the comparison mask to the CPU.
        correct += predicted.eq(targets).sum().item()

        if batch_idx%100==0:
            print('Loss: %.3f | Acc: %.3f%% (%d/%d)'
                % (train_loss/(batch_idx+1), 100.*correct/total, correct, total),f"learnig rate {lr,global_iter}")

        global_iter+=1

    if total == 0:
        raise ValueError("trainLoader yielded no samples")

    train_accuracy = 100. * correct / total
    return train_loss / len(trainLoader),  train_accuracy,lr
    

def test(net,valLoader,epoch):
    """Evaluate ``net`` on ``valLoader`` without gradients and print a summary.

    Args:
        net: model to evaluate; switched to eval mode.
        valLoader: iterable of ``(inputs, targets)`` mini-batches; must support ``len()``.
        epoch: accepted for call compatibility; not used.

    Returns:
        ``(mean loss per batch, accuracy in percent)``.
    """
    global best_acc  # declared here but neither read nor assigned in this function
    net.eval()
    loss_sum, n_correct, n_seen = 0, 0, 0

    with torch.no_grad():
        for step, (inputs, targets) in enumerate(valLoader):
            if device == 'cuda':
                inputs = inputs.cuda(device)
                targets = targets.cuda(device)

            logits = net(inputs)
            loss_sum += criterion(logits, targets).data.item()

            predicted = torch.max(logits.data, 1)[1]
            n_seen += targets.size(0)
            n_correct += predicted.eq(targets.data).cpu().sum()

            # Progress line every 100 batches
            if step % 100 != 0:
                continue
            print('Test Loss: %.3f | Test Acc: %.3f%% (%d/%d)'
                % (loss_sum/(step+1), 100.*n_correct.item()/n_seen, n_correct, n_seen))

    test_accuracy = 100. * n_correct.item() / n_seen
    mean_loss = loss_sum/len(valLoader)

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n'.format(
        mean_loss, n_correct, n_seen, test_accuracy))

    return mean_loss, test_accuracy

In [None]:
# Initialise Weights & Biases; fill in `project` and `run_name` for your run.
# The API key is read from the WANDB_API_KEY environment variable instead of
# being pasted into the notebook, so it cannot leak through a shared .ipynb.
import wandb
wandb.login(key=os.environ.get("WANDB_API_KEY"))

run_name = f""

# The logged config mirrors the values the run actually uses:
#   learning_rate -> lr_0 (initial learning rate), not the momentum term alpha
#   batch_size    -> the training DataLoader's batch size, not the dataset size
wandb.init(project='', name=run_name, config={
        'learning_rate': lr_0,
        'batch_size': trainloader.batch_size,
        'epochs': epochs,
        'weight_decay': weight_decay,
        'temperature': temperature,
        'alpha': alpha,
        'datasize': datasize,
        'cycles': M,
    })


In [None]:
# RUN THE MODEL

saved_epochs=set()  # epochs at which a snapshot was actually written
total_models_to_save = 12
last_epochs_to_save = 3
random_models_to_save = total_models_to_save - last_epochs_to_save
mt = 0  # running index used in snapshot file names

print(num_batch)

first_epoch = epochs - epochs_to_run

# NOTE(review): the three collections below are not consulted by the loop
# (snapshots follow the per-cycle rule further down); they are kept so that
# code relying on these names keeps working — confirm whether still needed.
random_candidates = range(first_epoch, epochs - last_epochs_to_save)
random_saved_epochs = sorted(random.sample(
    random_candidates,
    min(random_models_to_save, len(random_candidates)),  # never ask for more than exist
))
final_saved_epochs = list(range(epochs - last_epochs_to_save, epochs))
all_saved_epochs = set(random_saved_epochs + final_saved_epochs)


# Run the last `epochs_to_run` epochs of the schedule. The upper bound is
# `epochs`: using `epochs_to_run` there yields an empty range as soon as
# epochs_to_run <= epochs / 2.
for epoch in range(first_epoch, epochs):
    # Training loop
    train_loss, train_accuracy, learn_rate = train(
        net, trainloader, alpha, temperature, datasize, num_batch,
        epoch, epochs, weight_decay, T)

    # Validation loop
    test_loss, test_accuracy = test(net, valloader, epoch)

    wandb.log({
        # train() already returns the mean loss per batch; do not divide again
        'train_loss': train_loss,
        'train_accuracy': train_accuracy,
        'test_loss': test_loss,
        "test_accuracy": test_accuracy,
        "epoch": epoch,
        "lr": learn_rate
    })

    # Save 3 snapshots per 50-epoch cycle: the last 3 epochs of each cycle,
    # which fall inside the noise-injection window of the update rule.
    if (epoch % 50) + 1 > 47:
        print(f"Saving model snapshot at epoch {epoch}")
        net.cpu()  # store CPU tensors so the checkpoint loads on any machine
        torch.save(net.state_dict(), f'{save_dir}/cifar100_csghmc0.5_{mt}.pt')
        mt += 1
        saved_epochs.add(epoch)
        net.to(device)  # unlike net.cuda(device), also valid when device == 'cpu'


print(saved_epochs)