In [95]:
import sys
sys.path.append("../")
sys.path.append("spsa/")


from pytorch_optim_training_manager import train_manager
import torch
import torchvision
import torchvision.transforms as transforms
import models
import os

# Configure the environment BEFORE anything queries CUDA: the CUDA runtime reads
# CUDA_DEVICE_ORDER / CUDA_VISIBLE_DEVICES once, when the driver is initialized,
# and torch.cuda.is_available() can trigger that initialization. Setting them
# afterwards (as this cell previously did) may silently have no effect.
os.environ["WANDB_DISABLED"] = "true"
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"  # enumerate GPUs by PCI bus id
os.environ["CUDA_VISIBLE_DEVICES"] = "0"        # expose only GPU 0 to this process
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

from my_spsa import SPSA as SPSA

In [96]:
# Seed torch's RNGs so weight init and loader shuffling are reproducible.
torch.manual_seed(0)
torch.cuda.manual_seed_all(0)

# ToTensor gives pixels in [0, 1]; Normalize(0.5, 0.5) rescales them to [-1, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

# FashionMNIST train split first, then test split, both cached under ./data.
training_set, validation_set = (
    torchvision.datasets.FashionMNIST("./data", train=is_train, transform=transform, download=True)
    for is_train in (True, False)
)

# Only the training loader shuffles; validation order stays fixed.
training_loader = torch.utils.data.DataLoader(training_set, batch_size=64, shuffle=True)
validation_loader = torch.utils.data.DataLoader(validation_set, batch_size=64, shuffle=False)


In [97]:
# Build the network and keep a persisted copy of its initial weights, presumably so
# that different optimizer experiments start from the identical initialization.
# Previously the cell saved and immediately re-loaded the same state_dict (a no-op)
# and overwrote any existing checkpoint; now an existing checkpoint is reused and a
# missing one is created (together with its directory).
INIT_WEIGHTS_PATH = "models/CNN_Simple.pt"
model = models.CNN_Simple()
if os.path.exists(INIT_WEIGHTS_PATH):
    # map_location keeps loading working on machines without the saving device.
    model.load_state_dict(torch.load(INIT_WEIGHTS_PATH, map_location="cpu"))
else:
    os.makedirs(os.path.dirname(INIT_WEIGHTS_PATH), exist_ok=True)
    torch.save(model.state_dict(), INIT_WEIGHTS_PATH)
loss_fn = torch.nn.CrossEntropyLoss()

In [98]:
# SPSA optimizer over all model parameters. NOTE(review): alpha/gamma look like the
# decay exponents of the step-size and perturbation gain sequences (0.602 / 0.101 are
# Spall's commonly recommended values) — confirm against my_spsa.SPSA.
optimizer = SPSA(
    model.parameters(),
    lr=0.1,       # base step size
    c=0.01,       # presumably the base perturbation magnitude — confirm
    alpha=0.602,
    gamma=0.101,
)

In [99]:
def _batch_mean(values, split_name):
    """Return the unweighted mean of per-batch losses; fail clearly if there were none."""
    if not values:
        raise ValueError(f"{split_name} loader produced no batches; cannot average the loss")
    return sum(values) / len(values)


def train(model, training_loader, validation_loader, optimizer, loss_fn, epochs, device=None):
    """Train ``model`` with a closure-based optimizer and report per-epoch losses.

    Each epoch performs one ``optimizer.step(closure)`` per training batch, then
    evaluates the model on ``validation_loader`` under ``torch.no_grad()``. Both
    epoch averages are printed; only the training averages are returned.

    Args:
        model: ``torch.nn.Module``; switched to train mode for updates and eval
            mode for validation.
        training_loader: iterable of ``(inputs, targets)`` batches used for updates.
        validation_loader: iterable of ``(inputs, targets)`` batches for evaluation.
        optimizer: object with ``zero_grad()`` and ``step(closure)``. The value
            returned by ``step`` is recorded as the batch's training loss — assumed
            to be the closure's float result; confirm for my_spsa.SPSA.
        loss_fn: callable ``loss_fn(outputs, targets)`` returning a scalar tensor.
        epochs: number of passes over ``training_loader``.
        device: optional ``torch.device`` (or device string). When given, every
            batch is moved there before the forward pass; the model must already
            live on that device. ``None`` (default) leaves batches untouched.

    Returns:
        list of float: mean training loss per epoch (unweighted mean over batches).

    Raises:
        ValueError: if either loader yields no batches during an epoch.
    """
    all_losses = []
    for epoch in range(epochs):
        # --- training pass ---
        model.train()
        train_losses = []
        for batch_input, batch_output in training_loader:
            if device is not None:
                batch_input = batch_input.to(device)
                batch_output = batch_output.to(device)

            # The optimizer may call the closure more than once per step (SPSA-style
            # methods typically evaluate perturbed losses). It reads this iteration's
            # batch, which is safe because step() returns before the loop rebinds it.
            # NOTE(review): backward() fills .grad; a gradient-free SPSA step may not
            # need it — kept because my_spsa's contract is not visible here.
            def closure():
                optimizer.zero_grad()
                outputs = model(batch_input)
                loss = loss_fn(outputs, batch_output)
                loss.backward()
                return loss.item()

            train_losses.append(optimizer.step(closure))

        # --- validation pass (no autograd bookkeeping needed) ---
        model.eval()
        val_losses = []
        with torch.no_grad():
            for batch_input, batch_output in validation_loader:
                if device is not None:
                    batch_input = batch_input.to(device)
                    batch_output = batch_output.to(device)
                outputs = model(batch_input)
                val_losses.append(loss_fn(outputs, batch_output).item())

        avg_train_loss = _batch_mean(train_losses, "training")
        avg_val_loss = _batch_mean(val_losses, "validation")
        print("Epoch", epoch+1, "Train Loss:", avg_train_loss, "Validation Loss:", avg_val_loss)
        all_losses.append(avg_train_loss)

    return all_losses


In [100]:
# Ten epochs of SPSA training. `loss` holds the per-epoch mean training loss
# (one value per epoch) and is written to disk in a later cell.
loss = train(
    model,
    training_loader,
    validation_loader,
    optimizer,
    loss_fn,
    epochs=10,
)

# Disabled alternative: the project's train_manager helper.
# manager = train_manager(model, loss_fn, optimizer, training_loader, validation_loader, device = device)
# losses = manager.train(20, verbose = True, eval_all_epochs = True, eval_mode = "loss")
# print(torch.min(torch.tensor(losses[1])))


Epoch 1 Train Loss: 1.7422709390679911 Validation Loss: 1.124567193590152
Epoch 2 Train Loss: 1.0213863375598689 Validation Loss: 1.1196840037206175
Epoch 3 Train Loss: 1.0303673657145835 Validation Loss: 0.9841419454592808
Epoch 4 Train Loss: 0.9401965135895113 Validation Loss: 0.8827937199811268
Epoch 5 Train Loss: 0.8241007687059293 Validation Loss: 0.8270860170103183
Epoch 6 Train Loss: 0.7981376669236592 Validation Loss: 0.8059132632556235
Epoch 7 Train Loss: 0.8005061178192147 Validation Loss: 0.8245333471115986
Epoch 8 Train Loss: 0.7950395473412105 Validation Loss: 0.8008185883236539
Epoch 9 Train Loss: 0.7721539818719506 Validation Loss: 0.7825187051751811
Epoch 10 Train Loss: 0.7544707402364531 Validation Loss: 0.7733289876561256


In [101]:
# Persist the per-epoch mean training losses as a 1-D tensor.
# torch.save does not create missing parent directories, so ensure results/ exists.
os.makedirs("results", exist_ok=True)
torch.save(torch.tensor(loss), "results/SPSA_P1.pt")