In [102]:
import sys
sys.path.append("../")
sys.path.append("spsa/")


from pytorch_optim_training_manager import train_manager
import torch
import torchvision
import torchvision.transforms as transforms
import models
import os

# Pin GPU selection *before* the first CUDA query: the CUDA runtime reads
# CUDA_DEVICE_ORDER / CUDA_VISIBLE_DEVICES when it is first initialised, so
# setting them after torch.cuda.is_available() may have no effect.
# NOTE(review): helpers imported above run before these assignments; if any of
# them touches CUDA or wandb at import time, move this block above the imports.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
# Ask Weights & Biases (if any imported helper uses it) not to log this run.
os.environ["WANDB_DISABLED"] = "true"

# Use the first visible GPU when CUDA is available, otherwise fall back to CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

from my_spsa import SPSA as SPSA

In [103]:
# Fix the RNG seeds (CPU and every CUDA device) so the run is repeatable.
torch.manual_seed(0)
torch.cuda.manual_seed_all(0)

# Convert images to tensors, then normalise with mean 0.5 / std 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

# FashionMNIST train / test splits, stored under ./data (downloaded if missing).
dataset_options = dict(root="./data", transform=transform, download=True)
training_set = torchvision.datasets.FashionMNIST(train=True, **dataset_options)
validation_set = torchvision.datasets.FashionMNIST(train=False, **dataset_options)

# Mini-batches of 64; only the training data is shuffled.
loader_options = dict(batch_size=64)
training_loader = torch.utils.data.DataLoader(training_set, shuffle=True, **loader_options)
validation_loader = torch.utils.data.DataLoader(validation_set, shuffle=False, **loader_options)


In [104]:
# Build a fresh CNN_Simple, write its initial weights to disk, and read them
# straight back so the network starts from exactly the checkpointed state.
initial_weights_path = "models/CNN_Simple.pt"
model = models.CNN_Simple()
torch.save(model.state_dict(), initial_weights_path)
initial_state = torch.load(initial_weights_path)
model.load_state_dict(initial_state)

# Objective used for both the training and the validation pass.
loss_fn = torch.nn.CrossEntropyLoss()

In [105]:
# SPSA optimizer over all model parameters.
# NOTE(review): lr / c / alpha / gamma are presumably the gain and perturbation
# schedule settings of SPSA — confirm their exact meaning against my_spsa.SPSA.
spsa_settings = {"lr": 0.1, "c": 0.01, "alpha": 0.602, "gamma": 0.101}
optimizer = SPSA(model.parameters(), **spsa_settings)

In [106]:
def _mean_or_nan(values):
    """Return the arithmetic mean of ``values``, or NaN when ``values`` is empty."""
    return sum(values) / len(values) if values else float("nan")


def train(model, training_loader, validation_loader, optimizer, loss_fn, epochs):
    """Run a closure-based training loop and report per-epoch average losses.

    For every epoch the model is put in train mode and ``optimizer.step`` is
    called once per training batch with a closure that re-evaluates the loss on
    that batch. The model is then switched to eval mode and the loss is computed
    on every validation batch under ``torch.no_grad()``. The average training
    and validation losses of the epoch are printed.

    Args:
        model: Callable module exposing ``train()`` / ``eval()``.
        training_loader: Iterable of ``(batch_input, batch_output)`` pairs.
        validation_loader: Iterable of ``(batch_input, batch_output)`` pairs.
        optimizer: Object with ``zero_grad()`` and ``step(closure)``; the value
            returned by ``step`` is recorded as the batch loss.
        loss_fn: Callable ``loss_fn(outputs, batch_output)`` returning a loss
            with ``backward()`` and ``item()``.
        epochs: Number of passes over ``training_loader``.

    Returns:
        A list with one average training loss per epoch. An epoch whose
        training loader yields no batches contributes NaN instead of raising
        ``ZeroDivisionError``; the same applies to the printed validation
        average when the validation loader is empty.
    """
    # Kept outside the loop so the model ends up in train mode even if epochs == 0.
    model.train()
    all_losses = []
    for epoch in range(epochs):
        # training set
        model.train()
        train_losses = []
        for batch_input, batch_output in training_loader:
            def closure():
                # Re-evaluates the loss on the current batch each time the
                # optimizer asks for it.
                optimizer.zero_grad()
                outputs = model(batch_input)
                loss = loss_fn(outputs, batch_output)
                # NOTE(review): backward() is kept as in the original; whether
                # the optimizer actually consumes gradients is not visible here.
                loss.backward()
                return loss.item()

            loss = optimizer.step(closure)
            train_losses.append(loss)
            optimizer.zero_grad()

        # validation set
        model.eval()
        val_losses = []
        with torch.no_grad():
            for batch_input, batch_output in validation_loader:
                outputs = model(batch_input)
                loss = loss_fn(outputs, batch_output)
                val_losses.append(loss.item())

        # An empty loader yields NaN rather than crashing after the epoch's work.
        avg_train_loss = _mean_or_nan(train_losses)
        avg_val_loss = _mean_or_nan(val_losses)
        print("Epoch", epoch+1, "Train Loss:", avg_train_loss, "Validation Loss:", avg_val_loss)
        all_losses.append(avg_train_loss)

    return all_losses


In [107]:
# Run 20 epochs with the hand-written loop; `loss` receives one average
# training loss per epoch.
loss = train(
    model=model,
    training_loader=training_loader,
    validation_loader=validation_loader,
    optimizer=optimizer,
    loss_fn=loss_fn,
    epochs=20,
)

# Alternative kept for reference: the same run driven by train_manager.
#print(loss)
#manager = train_manager(model, loss_fn, optimizer, training_loader, validation_loader, device = device)
#losses = manager.train(20, verbose = True, eval_all_epochs = True, eval_mode = "loss")
#print(torch.min(torch.tensor(losses[1])))


Epoch 1 Train Loss: 1.7422709390679911 Validation Loss: 1.124567193590152


In [None]:
# torch.save does not create missing parent directories; make sure results/
# exists so the per-epoch training losses are not lost after a long run.
os.makedirs("results", exist_ok=True)
torch.save(torch.tensor(loss), "results/SPSA_P1.pt")