In [14]:
import sys
sys.path.append("../")
sys.path.append("spsa/")


from pytorch_optim_training_manager import train_manager
import torch
import torchvision
import torchvision.transforms as transforms
import models
import os

# Environment must be configured BEFORE the first CUDA query: once
# torch.cuda.is_available() has probed the driver, later changes to
# CUDA_DEVICE_ORDER / CUDA_VISIBLE_DEVICES may no longer be honoured.
os.environ["WANDB_DISABLED"] = "true"
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"]="0"

# Use the first visible GPU when there is one, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

from my_spsa import SPSA as SPSA

In [15]:
# Fix the CPU generator and every CUDA generator to the same seed.
for seed_fn in (torch.manual_seed, torch.cuda.manual_seed_all):
    seed_fn(0)

# Preprocessing applied to every image: replicate the single grey channel to
# three channels, upscale to 32x32, convert to a tensor and normalise each
# channel with mean 0.5 / std 0.5.
preprocessing_steps = [
    transforms.Grayscale(num_output_channels=3),
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,) * 3, (0.5,) * 3),
]
transform = transforms.Compose(preprocessing_steps)

# FashionMNIST train split first, then the test split (used here for validation).
training_set, validation_set = (
    torchvision.datasets.FashionMNIST(
        "./data", train=is_train, transform=transform, download=True
    )
    for is_train in (True, False)
)

# Only the training batches are shuffled; validation order stays fixed.
training_loader = torch.utils.data.DataLoader(
    dataset=training_set, shuffle=True, batch_size=64
)
validation_loader = torch.utils.data.DataLoader(
    dataset=validation_set, shuffle=False, batch_size=64
)



In [16]:
# Build the network, write its freshly initialised weights to disk and load
# them straight back, so the run starts from exactly the checkpointed state
# (presumably so other experiments can reuse the same initial weights --
# confirm against the rest of the project).
initial_weights_path = "models/All_CNN_C.pt"
model = models.All_CNN_C()
torch.save(model.state_dict(), initial_weights_path)
initial_state = torch.load(initial_weights_path)
model.load_state_dict(initial_state)

# Move the network to the selected device and set up the classification loss.
model = model.to(device)
loss_fn = torch.nn.CrossEntropyLoss()


In [17]:
# SPSA optimiser from the local my_spsa module, driven through a closure in
# the training loop. NOTE(review): alpha/gamma look like SPSA gain-decay
# exponents and c like the perturbation size -- confirm against my_spsa.
spsa_hyperparameters = dict(lr=0.1, c=0.01, alpha=0.602, gamma=0.101)
optimizer = SPSA(model.parameters(), **spsa_hyperparameters)

In [18]:
def train(model, training_loader, validation_loader, optimizer, loss_fn, epochs):
    """Train ``model`` with a closure-driven optimizer and print per-epoch metrics.

    Each training batch is handed to ``optimizer.step(closure)``; the closure
    only evaluates the loss (no ``backward()`` call), so the optimizer is
    expected to work without autograd gradients. After every epoch the model
    is evaluated on ``validation_loader`` and a summary line is printed.

    Args:
        model: ``torch.nn.Module`` to train; moved to the module-level ``device``.
        training_loader: iterable of ``(inputs, targets)`` training batches.
        validation_loader: iterable of ``(inputs, targets)`` validation batches.
        optimizer: object exposing ``zero_grad()`` and ``step(closure)``. The
            value returned by ``step`` is recorded as the batch training loss
            (presumably the closure's loss -- depends on the optimizer).
        loss_fn: callable ``loss_fn(outputs, targets)`` returning a scalar tensor.
        epochs: number of passes over ``training_loader``.

    Returns:
        ``(avg_train_loss, train_accuracy)`` of the last epoch. Both are NaN
        when ``epochs`` is 0; either is NaN when the training loader is empty.
    """

    def _mean(values):
        # Average that tolerates an empty loader instead of dividing by zero.
        return sum(values) / len(values) if values else float("nan")

    def _ratio(correct, total):
        # Accuracy that tolerates an empty loader instead of dividing by zero.
        return correct / total if total else float("nan")

    model.to(device)

    # Defined up front so that epochs == 0 returns NaNs instead of raising
    # UnboundLocalError at the return statement.
    avg_train_loss = float("nan")
    train_accuracy = float("nan")

    for epoch in range(epochs):
        # Training
        model.train()
        train_losses = []
        train_correct = 0
        train_total = 0

        for batch_input, batch_output in training_loader:
            batch_input, batch_output = batch_input.to(device), batch_output.to(device)

            # Loss evaluation handed to the optimizer; it may be called
            # several times per step.
            def closure():
                optimizer.zero_grad()
                # Only the scalar loss value is returned and backward() is
                # never called, so skip building the autograd graph.
                with torch.no_grad():
                    outputs = model(batch_input)
                    loss = loss_fn(outputs, batch_output)
                return loss.item()

            loss = optimizer.step(closure)
            train_losses.append(loss)
            optimizer.zero_grad()

            # Training accuracy, measured on the same batch after the update
            # (model is still in train mode here).
            with torch.no_grad():
                outputs = model(batch_input)
                _, predicted = torch.max(outputs, 1)
                train_correct += (predicted == batch_output).sum().item()
                train_total += batch_output.size(0)

        # Validation
        model.eval()
        val_losses = []
        val_correct = 0
        val_total = 0
        with torch.no_grad():
            for batch_input, batch_output in validation_loader:
                batch_input, batch_output = batch_input.to(device), batch_output.to(device)
                outputs = model(batch_input)
                loss = loss_fn(outputs, batch_output)
                val_losses.append(loss.item())

                _, predicted = torch.max(outputs, 1)
                val_correct += (predicted == batch_output).sum().item()
                val_total += batch_output.size(0)

        avg_train_loss = _mean(train_losses)
        avg_val_loss = _mean(val_losses)

        train_accuracy = _ratio(train_correct, train_total)
        val_accuracy = _ratio(val_correct, val_total)

        print("Epoch", epoch+1, "- Train Loss:", avg_train_loss, "Validation Loss:", avg_val_loss, "Train Accuracy:", train_accuracy, "Validation Accuracy:", val_accuracy)

    return avg_train_loss, train_accuracy


In [19]:
# Run the full 20-epoch training; keeps the last epoch's training loss and accuracy.
loss, accuracy = train(
    model=model,
    training_loader=training_loader,
    validation_loader=validation_loader,
    optimizer=optimizer,
    loss_fn=loss_fn,
    epochs=20,
)

KeyboardInterrupt: 

In [None]:
# Persist the final-epoch training metrics. The "_acc" file must hold the
# accuracy (the original wrote the loss to both files).
os.makedirs("results", exist_ok=True)  # torch.save does not create missing directories
torch.save(torch.tensor(loss), "results/SPSA_All_CNN_P1.pt")
torch.save(torch.tensor(accuracy), "results/SPSA_All_CNN_P1_acc.pt")