In [5]:
# Hyperparameter tuning
import BreakfastNaive as BF
import torch
import torch.nn as nn
from model import NeuralNet
import os

# Device configuration: run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [6]:
# Hyper-parameters

# -- network dimensions (the three arguments handed to NeuralNet) --
input_size = 2048
hidden_size = 512
num_classes = 48

# -- optimisation schedule --
learning_rate = 1e-5  # TODO tune
batch_size = 100      # TODO increase
num_epochs = 100      # TODO tune

In [7]:
# DATASET
# All inputs live under a single root directory. Set the BREAKFAST_ROOT environment
# variable to point the notebook at another checkout; when it is unset, the original
# local location is used, so existing runs on that machine resolve the same paths.
data_root = os.environ.get(
    "BREAKFAST_ROOT", r"C:\Users\dcsang\PycharmProjects\embedding\breakfast"
)

visual_feat_path = os.path.join(data_root, "bf_kinetics_feat")
text_path = os.path.join(data_root, "groundTruth")
map_path = os.path.join(data_root, "mapping.txt")

# Each of the feature and ground-truth directories has a "train" and a "test" sub-directory.
visual_feat_path_train = os.path.join(visual_feat_path, "train")
text_path_train = os.path.join(text_path, "train")

visual_feat_path_test = os.path.join(visual_feat_path, "test")
text_path_test = os.path.join(text_path, "test")

# One BreakfastNaive dataset per split; both share the same mapping file.
train_dataset = BF.BreakfastNaive(visual_feat_path_train, text_path_train, map_path)
test_dataset = BF.BreakfastNaive(visual_feat_path_test, text_path_test, map_path)

# Batched iteration over each split; only the training split is shuffled.
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# MODEL: build the network, then move it onto the selected device.
model = NeuralNet(input_size, hidden_size, num_classes)
model = model.to(device)

# Loss and optimizer: cross-entropy objective, Adam over all model parameters.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate)



In [8]:
# TEST
def test_fn():
    with torch.no_grad():
        correct = 0
        total = 0
        for data_batch in test_loader:
            vis_feats = data_batch['vis_feats'].to(device)
            labels = data_batch['labels'].to(device)

            outputs = model(vis_feats.float())
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            #TODO print confusion matrix

        print('Accuracy over entire test set: {} %'.format(100 * correct / total))


In [9]:
# TRAIN
# Each epoch makes one optimisation pass over train_loader, logging the loss of every
# 10th step, and then calls test_fn() to report accuracy on the test split.
total_step = len(train_loader)
for epoch in range(num_epochs):
    # Evaluation code runs between epochs, so (re-)enter training mode explicitly here
    # rather than relying on whatever mode the model was left in.
    model.train()

    for i, data_batch in enumerate(train_loader):
        vis_feats = data_batch['vis_feats'].to(device)
        labels = data_batch['labels'].to(device)  # TODO dont send one of them to gpu and see what happens?

        # fwd pass: features are cast to float for the model and labels to long for the loss
        outputs = model(vis_feats.float())
        loss = criterion(outputs, labels.long())

        # bkwd pass and optimization: clear stale gradients before accumulating new ones
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (i + 1) % 10 == 0:
            print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'
                  .format(epoch + 1, num_epochs, i + 1, total_step, loss.item()))

    print("Testing epoch: ", epoch+1)
    test_fn()


# Save the model checkpoint
# torch.save(model.state_dict(), 'model.ckpt')


Epoch [1/100], Step [10/99], Loss: 7.2051
Epoch [1/100], Step [20/99], Loss: 5.5408
Epoch [1/100], Step [30/99], Loss: 5.1034
Epoch [1/100], Step [40/99], Loss: 4.2136
Epoch [1/100], Step [50/99], Loss: 3.8762
Epoch [1/100], Step [60/99], Loss: 3.9856
Epoch [1/100], Step [70/99], Loss: 4.1490
Epoch [1/100], Step [80/99], Loss: 3.7173
Epoch [1/100], Step [90/99], Loss: 3.7504
Testing epoch:  1
Accuracy over entire test set: 12.53529079616036 %
Epoch [2/100], Step [10/99], Loss: 3.4193
Epoch [2/100], Step [20/99], Loss: 3.4355
Epoch [2/100], Step [30/99], Loss: 3.4509
Epoch [2/100], Step [40/99], Loss: 3.5140
Epoch [2/100], Step [50/99], Loss: 3.5621
Epoch [2/100], Step [60/99], Loss: 3.4724
Epoch [2/100], Step [70/99], Loss: 3.3407
Epoch [2/100], Step [80/99], Loss: 3.6981
Epoch [2/100], Step [90/99], Loss: 3.4377
Testing epoch:  2
Accuracy over entire test set: 16.544325239977415 %
Epoch [3/100], Step [10/99], Loss: 3.4437
Epoch [3/100], Step [20/99], Loss: 3.4706
Epoch [3/100], Step [

Epoch [19/100], Step [30/99], Loss: 3.1389
Epoch [19/100], Step [40/99], Loss: 3.1873
Epoch [19/100], Step [50/99], Loss: 3.3303
Epoch [19/100], Step [60/99], Loss: 3.0700
Epoch [19/100], Step [70/99], Loss: 3.2938
Epoch [19/100], Step [80/99], Loss: 3.1956
Epoch [19/100], Step [90/99], Loss: 3.1341
Testing epoch:  19
Accuracy over entire test set: 26.199887069452288 %
Epoch [20/100], Step [10/99], Loss: 3.1474
Epoch [20/100], Step [20/99], Loss: 3.2030
Epoch [20/100], Step [30/99], Loss: 3.1807
Epoch [20/100], Step [40/99], Loss: 3.5353
Epoch [20/100], Step [50/99], Loss: 2.9266
Epoch [20/100], Step [60/99], Loss: 3.4479
Epoch [20/100], Step [70/99], Loss: 3.1311
Epoch [20/100], Step [80/99], Loss: 3.0249
Epoch [20/100], Step [90/99], Loss: 3.2900
Testing epoch:  20
Accuracy over entire test set: 25.974025974025974 %
Epoch [21/100], Step [10/99], Loss: 3.1486
Epoch [21/100], Step [20/99], Loss: 3.1820
Epoch [21/100], Step [30/99], Loss: 3.3243
Epoch [21/100], Step [40/99], Loss: 3.147

Epoch [37/100], Step [20/99], Loss: 3.0114
Epoch [37/100], Step [30/99], Loss: 2.8637
Epoch [37/100], Step [40/99], Loss: 3.0014
Epoch [37/100], Step [50/99], Loss: 2.7610
Epoch [37/100], Step [60/99], Loss: 3.1174
Epoch [37/100], Step [70/99], Loss: 3.1584
Epoch [37/100], Step [80/99], Loss: 3.0114
Epoch [37/100], Step [90/99], Loss: 3.0030
Testing epoch:  37
Accuracy over entire test set: 27.837380011293053 %
Epoch [38/100], Step [10/99], Loss: 3.1781
Epoch [38/100], Step [20/99], Loss: 3.0155
Epoch [38/100], Step [30/99], Loss: 3.1188
Epoch [38/100], Step [40/99], Loss: 3.0568
Epoch [38/100], Step [50/99], Loss: 3.0653
Epoch [38/100], Step [60/99], Loss: 2.8682
Epoch [38/100], Step [70/99], Loss: 3.1213
Epoch [38/100], Step [80/99], Loss: 3.3487
Epoch [38/100], Step [90/99], Loss: 3.0011
Testing epoch:  38
Accuracy over entire test set: 26.70807453416149 %
Epoch [39/100], Step [10/99], Loss: 3.0302
Epoch [39/100], Step [20/99], Loss: 2.7835
Epoch [39/100], Step [30/99], Loss: 3.1690

Epoch [55/100], Step [10/99], Loss: 3.0629
Epoch [55/100], Step [20/99], Loss: 2.8406
Epoch [55/100], Step [30/99], Loss: 2.8711
Epoch [55/100], Step [40/99], Loss: 3.0815
Epoch [55/100], Step [50/99], Loss: 3.0043
Epoch [55/100], Step [60/99], Loss: 3.0106
Epoch [55/100], Step [70/99], Loss: 3.0387
Epoch [55/100], Step [80/99], Loss: 2.9564
Epoch [55/100], Step [90/99], Loss: 2.9727
Testing epoch:  55
Accuracy over entire test set: 27.272727272727273 %
Epoch [56/100], Step [10/99], Loss: 2.9752
Epoch [56/100], Step [20/99], Loss: 2.8058
Epoch [56/100], Step [30/99], Loss: 2.9098
Epoch [56/100], Step [40/99], Loss: 3.0519
Epoch [56/100], Step [50/99], Loss: 2.9290
Epoch [56/100], Step [60/99], Loss: 2.6735
Epoch [56/100], Step [70/99], Loss: 2.9145
Epoch [56/100], Step [80/99], Loss: 2.9493
Epoch [56/100], Step [90/99], Loss: 3.1502
Testing epoch:  56
Accuracy over entire test set: 28.85375494071146 %
Epoch [57/100], Step [10/99], Loss: 2.9627
Epoch [57/100], Step [20/99], Loss: 3.0050

Epoch [73/100], Step [10/99], Loss: 3.0480
Epoch [73/100], Step [20/99], Loss: 2.7709
Epoch [73/100], Step [30/99], Loss: 2.7884
Epoch [73/100], Step [40/99], Loss: 2.7657
Epoch [73/100], Step [50/99], Loss: 2.9897
Epoch [73/100], Step [60/99], Loss: 2.8685
Epoch [73/100], Step [70/99], Loss: 2.9026
Epoch [73/100], Step [80/99], Loss: 2.9535
Epoch [73/100], Step [90/99], Loss: 3.0139
Testing epoch:  73
Accuracy over entire test set: 28.51496329757199 %
Epoch [74/100], Step [10/99], Loss: 2.7785
Epoch [74/100], Step [20/99], Loss: 2.8043
Epoch [74/100], Step [30/99], Loss: 2.6922
Epoch [74/100], Step [40/99], Loss: 2.8299
Epoch [74/100], Step [50/99], Loss: 3.0544
Epoch [74/100], Step [60/99], Loss: 2.9116
Epoch [74/100], Step [70/99], Loss: 2.8240
Epoch [74/100], Step [80/99], Loss: 2.8802
Epoch [74/100], Step [90/99], Loss: 2.7965
Testing epoch:  74
Accuracy over entire test set: 28.063241106719367 %
Epoch [75/100], Step [10/99], Loss: 2.8987
Epoch [75/100], Step [20/99], Loss: 2.9004

Epoch [91/100], Step [10/99], Loss: 2.7863
Epoch [91/100], Step [20/99], Loss: 2.6886
Epoch [91/100], Step [30/99], Loss: 2.7290
Epoch [91/100], Step [40/99], Loss: 2.7051
Epoch [91/100], Step [50/99], Loss: 2.7811
Epoch [91/100], Step [60/99], Loss: 2.8982
Epoch [91/100], Step [70/99], Loss: 2.4671
Epoch [91/100], Step [80/99], Loss: 2.8045
Epoch [91/100], Step [90/99], Loss: 2.9175
Testing epoch:  91
Accuracy over entire test set: 29.023150762281198 %
Epoch [92/100], Step [10/99], Loss: 3.0537
Epoch [92/100], Step [20/99], Loss: 2.7372
Epoch [92/100], Step [30/99], Loss: 2.7722
Epoch [92/100], Step [40/99], Loss: 2.8611
Epoch [92/100], Step [50/99], Loss: 2.6463
Epoch [92/100], Step [60/99], Loss: 2.5967
Epoch [92/100], Step [70/99], Loss: 2.6709
Epoch [92/100], Step [80/99], Loss: 2.7288
Epoch [92/100], Step [90/99], Loss: 2.6927
Testing epoch:  92
Accuracy over entire test set: 28.684359119141728 %
Epoch [93/100], Step [10/99], Loss: 2.5312
Epoch [93/100], Step [20/99], Loss: 2.746