In [None]:
from transformers import BertTokenizer, TFBertModel
import torch
import pickle

object_map = {0: "bicycle", 1: "bus", 2: "crosswalk", 3: "pedestrian", 4: "pedestrian sign", 5: "stop sign", 6: "tactile paving", 7: "traffic light", 8: "truck", 9: "car", 10: "scooter", 11: "motorcycle"}
object_list = list(object_map.values())

tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
model = TFBertModel.from_pretrained("bert-base-cased")

output_embeddings = []
for obj in object_list:
    encoded_input = tokenizer(obj, return_tensors='tf')
    output = model(encoded_input)
    output_embeddings.append(torch.Tensor(output.last_hidden_state.numpy().sum(axis=1).squeeze(0)))
    
file_name = 'vocab_embeddings.pkl'
with open(file_name, 'wb') as f:
    pickle.dump(output_embeddings, f)

In [2]:
import pickle
try:
    with open('vocab_embeddings.pkl', 'rb') as file:
        output_embeddings = pickle.load(file)
except FileNotFoundError:
    print("The file was not found.")

In [32]:
import csv
csv_file = 'sgg_data.csv'
features = []
images = []
labels = []
with open(csv_file, 'r') as file:
    reader = csv.reader(file)
    for row in reader:
        images.append(row[0])
        labels.append(int(row[-1]))
        bb1 = [float(item) for item in row[1:5]]
        bb2 = [float(item) for item in row[6:10]]
        diff_bb = [float(bb1[i]) - float(bb2[i]) for i in range(len(bb1))]
        features.append([int(row[5])] + [int(row[10])] + bb1 + bb2 + diff_bb)

In [42]:
import torch
X = []
for i in range(len(features)):
    X.append(torch.cat((torch.Tensor(features[i][2:]), output_embeddings[features[i][0]], output_embeddings[features[i][1]], output_embeddings[features[i][0]] - output_embeddings[features[i][1]])))
X = torch.stack(X)
Y = torch.tensor(labels)

In [None]:
import torch
X = []
object_classes_len = 12
for i in range(len(features)):
    X.append(torch.cat((torch.Tensor(features[i][2:]), torch.nn.functional.one_hot(torch.tensor(features[i][0]), num_classes=object_classes_len), torch.nn.functional.one_hot(torch.tensor(features[i][0]), num_classes=object_classes_len))))
X = torch.stack(X)
Y = torch.tensor(labels)

In [45]:
X[0]

tensor([ 0.4414,  0.1630,  0.4898,  0.3667,  0.5753,  0.1508,  0.6117,  0.3593,
        -0.1339,  0.0122, -0.1219,  0.0074,  0.0000,  0.0000,  0.0000,  1.0000,
         0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
         0.0000,  0.0000,  0.0000,  1.0000,  0.0000,  0.0000,  0.0000,  0.0000,
         0.0000,  0.0000,  0.0000,  0.0000])

In [43]:
from SceneGraph import SceneGraph
def create_model(num_classes, num_features):
    model = SceneGraph(num_classes, num_features)
    return model

In [46]:
def train_model(model, optimizer, criterion, features, labels):
    model.train()

    features, labels = Variable(features), Variable(labels)

    optimizer.zero_grad()
    outputs = model(features)
    loss = criterion(outputs, labels)
    loss.backward()
    optimizer.step()

    return loss.item()

In [47]:
def evaluate_model(model, features, labels):
    model.eval()

    features, labels = Variable(features), Variable(labels)

    outputs = model(features)
    _, predicted = torch.max(outputs.data, 1)
    total = labels.size(0)
    correct = (predicted == labels).sum().item()

    return correct / total

In [48]:
model = create_model(num_classes=16, num_features=X.shape[1])
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = torch.nn.CrossEntropyLoss()

In [49]:
from sklearn.model_selection import train_test_split
from torch.autograd import Variable

X_train, X_val, Y_train, Y_val = train_test_split(X, Y, test_size=0.1, random_state=42)
best_val_loss = float('inf')
best_model = None
for epoch in range(1000):
    loss = train_model(model, optimizer, criterion, X_train, Y_train)
    model.eval()
    with torch.no_grad():
        val_outputs = model(X_val)
        val_loss = criterion(val_outputs, Y_val).item()
        if val_loss < best_val_loss:
            best_model = model
            best_val_loss = val_loss
    print(f'Epoch: {epoch+1}/1000 | Training Loss: {loss:.4f} | Validation Loss: {val_loss:.4f}')


Epoch: 1/1000 | Training Loss: 2.7790 | Validation Loss: 2.7266
Epoch: 2/1000 | Training Loss: 2.7240 | Validation Loss: 2.6823
Epoch: 3/1000 | Training Loss: 2.6712 | Validation Loss: 2.6372
Epoch: 4/1000 | Training Loss: 2.6167 | Validation Loss: 2.5906
Epoch: 5/1000 | Training Loss: 2.5594 | Validation Loss: 2.5436
Epoch: 6/1000 | Training Loss: 2.5008 | Validation Loss: 2.4985
Epoch: 7/1000 | Training Loss: 2.4429 | Validation Loss: 2.4590
Epoch: 8/1000 | Training Loss: 2.3894 | Validation Loss: 2.4289
Epoch: 9/1000 | Training Loss: 2.3440 | Validation Loss: 2.4099
Epoch: 10/1000 | Training Loss: 2.3097 | Validation Loss: 2.4012
Epoch: 11/1000 | Training Loss: 2.2867 | Validation Loss: 2.3983
Epoch: 12/1000 | Training Loss: 2.2715 | Validation Loss: 2.3951
Epoch: 13/1000 | Training Loss: 2.2585 | Validation Loss: 2.3876
Epoch: 14/1000 | Training Loss: 2.2434 | Validation Loss: 2.3750
Epoch: 15/1000 | Training Loss: 2.2247 | Validation Loss: 2.3589
Epoch: 16/1000 | Training Loss: 2.

Epoch: 130/1000 | Training Loss: 1.1067 | Validation Loss: 1.9717
Epoch: 131/1000 | Training Loss: 1.1019 | Validation Loss: 1.9765
Epoch: 132/1000 | Training Loss: 1.0967 | Validation Loss: 1.9761
Epoch: 133/1000 | Training Loss: 1.0914 | Validation Loss: 1.9790
Epoch: 134/1000 | Training Loss: 1.0865 | Validation Loss: 1.9821
Epoch: 135/1000 | Training Loss: 1.0818 | Validation Loss: 1.9828
Epoch: 136/1000 | Training Loss: 1.0774 | Validation Loss: 1.9886
Epoch: 137/1000 | Training Loss: 1.0730 | Validation Loss: 1.9881
Epoch: 138/1000 | Training Loss: 1.0685 | Validation Loss: 1.9940
Epoch: 139/1000 | Training Loss: 1.0638 | Validation Loss: 1.9928
Epoch: 140/1000 | Training Loss: 1.0590 | Validation Loss: 1.9967
Epoch: 141/1000 | Training Loss: 1.0541 | Validation Loss: 1.9984
Epoch: 142/1000 | Training Loss: 1.0494 | Validation Loss: 2.0007
Epoch: 143/1000 | Training Loss: 1.0449 | Validation Loss: 2.0036
Epoch: 144/1000 | Training Loss: 1.0406 | Validation Loss: 2.0052
Epoch: 145

Epoch: 266/1000 | Training Loss: 0.7030 | Validation Loss: 2.6351
Epoch: 267/1000 | Training Loss: 0.7017 | Validation Loss: 2.6444
Epoch: 268/1000 | Training Loss: 0.7003 | Validation Loss: 2.6487
Epoch: 269/1000 | Training Loss: 0.6984 | Validation Loss: 2.6567
Epoch: 270/1000 | Training Loss: 0.6962 | Validation Loss: 2.6626
Epoch: 271/1000 | Training Loss: 0.6944 | Validation Loss: 2.6659
Epoch: 272/1000 | Training Loss: 0.6931 | Validation Loss: 2.6745
Epoch: 273/1000 | Training Loss: 0.6914 | Validation Loss: 2.6812
Epoch: 274/1000 | Training Loss: 0.6893 | Validation Loss: 2.6859
Epoch: 275/1000 | Training Loss: 0.6875 | Validation Loss: 2.6931
Epoch: 276/1000 | Training Loss: 0.6859 | Validation Loss: 2.6979
Epoch: 277/1000 | Training Loss: 0.6843 | Validation Loss: 2.7033
Epoch: 278/1000 | Training Loss: 0.6827 | Validation Loss: 2.7118
Epoch: 279/1000 | Training Loss: 0.6809 | Validation Loss: 2.7174
Epoch: 280/1000 | Training Loss: 0.6790 | Validation Loss: 2.7235
Epoch: 281

Epoch: 402/1000 | Training Loss: 0.5132 | Validation Loss: 3.4851
Epoch: 403/1000 | Training Loss: 0.5104 | Validation Loss: 3.4911
Epoch: 404/1000 | Training Loss: 0.5087 | Validation Loss: 3.4998
Epoch: 405/1000 | Training Loss: 0.5081 | Validation Loss: 3.5023
Epoch: 406/1000 | Training Loss: 0.5078 | Validation Loss: 3.5162
Epoch: 407/1000 | Training Loss: 0.5076 | Validation Loss: 3.5181
Epoch: 408/1000 | Training Loss: 0.5068 | Validation Loss: 3.5287
Epoch: 409/1000 | Training Loss: 0.5058 | Validation Loss: 3.5312
Epoch: 410/1000 | Training Loss: 0.5041 | Validation Loss: 3.5402
Epoch: 411/1000 | Training Loss: 0.5026 | Validation Loss: 3.5430
Epoch: 412/1000 | Training Loss: 0.5009 | Validation Loss: 3.5510
Epoch: 413/1000 | Training Loss: 0.4994 | Validation Loss: 3.5559
Epoch: 414/1000 | Training Loss: 0.4981 | Validation Loss: 3.5635
Epoch: 415/1000 | Training Loss: 0.4972 | Validation Loss: 3.5695
Epoch: 416/1000 | Training Loss: 0.4967 | Validation Loss: 3.5737
Epoch: 417

Epoch: 538/1000 | Training Loss: 0.3921 | Validation Loss: 4.2688
Epoch: 539/1000 | Training Loss: 0.3918 | Validation Loss: 4.2853
Epoch: 540/1000 | Training Loss: 0.3918 | Validation Loss: 4.2754
Epoch: 541/1000 | Training Loss: 0.3917 | Validation Loss: 4.2957
Epoch: 542/1000 | Training Loss: 0.3917 | Validation Loss: 4.2861
Epoch: 543/1000 | Training Loss: 0.3911 | Validation Loss: 4.3066
Epoch: 544/1000 | Training Loss: 0.3903 | Validation Loss: 4.2997
Epoch: 545/1000 | Training Loss: 0.3888 | Validation Loss: 4.3115
Epoch: 546/1000 | Training Loss: 0.3874 | Validation Loss: 4.3154
Epoch: 547/1000 | Training Loss: 0.3859 | Validation Loss: 4.3160
Epoch: 548/1000 | Training Loss: 0.3849 | Validation Loss: 4.3249
Epoch: 549/1000 | Training Loss: 0.3841 | Validation Loss: 4.3268
Epoch: 550/1000 | Training Loss: 0.3837 | Validation Loss: 4.3399
Epoch: 551/1000 | Training Loss: 0.3835 | Validation Loss: 4.3394
Epoch: 552/1000 | Training Loss: 0.3836 | Validation Loss: 4.3522
Epoch: 553

Epoch: 677/1000 | Training Loss: 0.3157 | Validation Loss: 4.9968
Epoch: 678/1000 | Training Loss: 0.3149 | Validation Loss: 4.9890
Epoch: 679/1000 | Training Loss: 0.3130 | Validation Loss: 4.9949
Epoch: 680/1000 | Training Loss: 0.3113 | Validation Loss: 4.9992
Epoch: 681/1000 | Training Loss: 0.3104 | Validation Loss: 4.9996
Epoch: 682/1000 | Training Loss: 0.3101 | Validation Loss: 5.0141
Epoch: 683/1000 | Training Loss: 0.3104 | Validation Loss: 5.0101
Epoch: 684/1000 | Training Loss: 0.3107 | Validation Loss: 5.0342
Epoch: 685/1000 | Training Loss: 0.3112 | Validation Loss: 5.0231
Epoch: 686/1000 | Training Loss: 0.3106 | Validation Loss: 5.0439
Epoch: 687/1000 | Training Loss: 0.3101 | Validation Loss: 5.0335
Epoch: 688/1000 | Training Loss: 0.3086 | Validation Loss: 5.0473
Epoch: 689/1000 | Training Loss: 0.3073 | Validation Loss: 5.0403
Epoch: 690/1000 | Training Loss: 0.3063 | Validation Loss: 5.0543
Epoch: 691/1000 | Training Loss: 0.3058 | Validation Loss: 5.0615
Epoch: 692

Epoch: 813/1000 | Training Loss: 0.2602 | Validation Loss: 5.6644
Epoch: 814/1000 | Training Loss: 0.2601 | Validation Loss: 5.6837
Epoch: 815/1000 | Training Loss: 0.2602 | Validation Loss: 5.6807
Epoch: 816/1000 | Training Loss: 0.2598 | Validation Loss: 5.6909
Epoch: 817/1000 | Training Loss: 0.2594 | Validation Loss: 5.6911
Epoch: 818/1000 | Training Loss: 0.2581 | Validation Loss: 5.6867
Epoch: 819/1000 | Training Loss: 0.2570 | Validation Loss: 5.6990
Epoch: 820/1000 | Training Loss: 0.2561 | Validation Loss: 5.7026
Epoch: 821/1000 | Training Loss: 0.2557 | Validation Loss: 5.7109
Epoch: 822/1000 | Training Loss: 0.2557 | Validation Loss: 5.7142
Epoch: 823/1000 | Training Loss: 0.2559 | Validation Loss: 5.7254
Epoch: 824/1000 | Training Loss: 0.2563 | Validation Loss: 5.7230
Epoch: 825/1000 | Training Loss: 0.2564 | Validation Loss: 5.7386
Epoch: 826/1000 | Training Loss: 0.2567 | Validation Loss: 5.7315
Epoch: 827/1000 | Training Loss: 0.2560 | Validation Loss: 5.7560
Epoch: 828

Epoch: 953/1000 | Training Loss: 0.2190 | Validation Loss: 6.3317
Epoch: 954/1000 | Training Loss: 0.2185 | Validation Loss: 6.3215
Epoch: 955/1000 | Training Loss: 0.2183 | Validation Loss: 6.3383
Epoch: 956/1000 | Training Loss: 0.2184 | Validation Loss: 6.3301
Epoch: 957/1000 | Training Loss: 0.2187 | Validation Loss: 6.3444
Epoch: 958/1000 | Training Loss: 0.2193 | Validation Loss: 6.3495
Epoch: 959/1000 | Training Loss: 0.2196 | Validation Loss: 6.3580
Epoch: 960/1000 | Training Loss: 0.2198 | Validation Loss: 6.3568
Epoch: 961/1000 | Training Loss: 0.2189 | Validation Loss: 6.3638
Epoch: 962/1000 | Training Loss: 0.2179 | Validation Loss: 6.3549
Epoch: 963/1000 | Training Loss: 0.2168 | Validation Loss: 6.3699
Epoch: 964/1000 | Training Loss: 0.2161 | Validation Loss: 6.3615
Epoch: 965/1000 | Training Loss: 0.2159 | Validation Loss: 6.3832
Epoch: 966/1000 | Training Loss: 0.2160 | Validation Loss: 6.3718
Epoch: 967/1000 | Training Loss: 0.2162 | Validation Loss: 6.3871
Epoch: 968

In [50]:
best_val_loss

1.9562649726867676

In [51]:
torch.save({
            'model_state_dict': best_model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            }, "./sgg_ckpt_wo_bert.pth")