In [None]:
from transformers import BertTokenizer, TFBertModel
import torch
import pickle

object_map = {
    0: "bicycle", 1: "bus", 2: "crosswalk", 3: "pedestrian",
    4: "pedestrian sign", 5: "stop sign", 6: "tactile paving", 7: "traffic light",
    8: "truck", 9: "car", 10: "scooter", 11: "motorcycle",
}
# Class ids double as list indices: the k-th embedding built from object_list is later
# looked up by class id. Require contiguous ids 0..n-1 and order the names by id
# explicitly instead of relying on dict insertion order.
assert sorted(object_map) == list(range(len(object_map))), "object_map ids must be 0..n-1"
object_list = [object_map[class_id] for class_id in sorted(object_map)]

tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
model = TFBertModel.from_pretrained("bert-base-cased")

# Build one fixed-size vector per object name, appended in object_list order so that
# output_embeddings[k] is the embedding of object_list[k].
output_embeddings = []
for object_name in object_list:
    bert_out = model(tokenizer(object_name, return_tensors='tf'))
    # Sum the hidden states over axis 1 (presumably the token axis of a
    # (1, n_tokens, hidden) array — TODO confirm), then drop the leading size-1 axis.
    summed = bert_out.last_hidden_state.numpy().sum(axis=1).squeeze(0)
    output_embeddings.append(torch.Tensor(summed))
    
# Persist the embedding list so a later session can reload it without re-running BERT.
file_name = 'vocab_embeddings.pkl'
with open(file_name, mode='wb') as out_fh:
    pickle.dump(output_embeddings, out_fh)

In [2]:
import pickle
# NOTE: pickle.load can execute arbitrary code — only load a vocab_embeddings.pkl you
# produced yourself (e.g. with the BERT embedding cell above).
try:
    with open('vocab_embeddings.pkl', 'rb') as file:
        output_embeddings = pickle.load(file)
except FileNotFoundError as err:
    # Fail loudly: printing and continuing would leave output_embeddings undefined (or
    # silently stale), and the feature-building cell would break far from the real cause.
    raise FileNotFoundError(
        "vocab_embeddings.pkl was not found; run the BERT embedding cell first to create it."
    ) from err

In [22]:
import csv
csv_file = 'sgg_data.csv'
# Column layout, as consumed by the indexing below (NOTE(review): confirm against the
# CSV producer):
#   0      image identifier
#   1-4    four floats for the first object (presumably its bounding box)
#   5      class id of the first object (indexes output_embeddings later)
#   6-9    four floats for the second object
#   10     class id of the second object
#   last   integer target label
features = []
images = []
labels = []
# newline='' is how the csv module expects files it reads to be opened.
with open(csv_file, 'r', newline='') as file:
    reader = csv.reader(file)
    for row in reader:
        if not row:
            # csv.reader yields [] for blank lines (e.g. a trailing newline); skip them
            # instead of failing with IndexError.
            continue
        images.append(row[0])
        labels.append(int(row[-1]))
        bb1 = [float(item) for item in row[1:5]]
        bb2 = [float(item) for item in row[6:10]]
        # Element-wise difference of the two 4-float groups (already floats).
        diff_bb = [a - b for a, b in zip(bb1, bb2)]
        # 14 numbers: [cls1, bb1(4), cls2, bb2(4), diff(4)]; cls1/cls2 sit at indices 0 and 5.
        features.append([int(row[5])] + bb1 + [int(row[10])] + bb2 + diff_bb)

In [23]:
import torch
# One input row per pair: the row's numeric features, followed by the embeddings of the
# two class ids (positions 0 and 5 of the feature list) and their difference.
rows = []
for pair in features:
    first_emb = output_embeddings[pair[0]]
    second_emb = output_embeddings[pair[5]]
    rows.append(torch.cat((torch.Tensor(pair), first_emb, second_emb, first_emb - second_emb)))
X = torch.stack(rows)
Y = torch.tensor(labels)

In [24]:
from SceneGraph import SceneGraph
def create_model(num_classes, num_features):
    """Construct a fresh SceneGraph classifier.

    Args:
        num_classes: number of output classes, forwarded to SceneGraph.
        num_features: width of one input feature vector, forwarded to SceneGraph.

    Returns:
        The newly constructed SceneGraph instance.
    """
    return SceneGraph(num_classes, num_features)

In [25]:
def train_model(model, optimizer, criterion, features, labels):
    """Run one full-batch optimisation step and return the training loss.

    Args:
        model: module to train; it is switched to training mode.
        optimizer: optimizer over ``model``'s parameters.
        criterion: loss callable, invoked as ``criterion(outputs, labels)``.
        features: input tensor fed to ``model`` in a single batch.
        labels: target tensor passed to ``criterion``.

    Returns:
        float: the loss computed before the parameter update.
    """
    model.train()

    # Plain tensors participate in autograd directly; wrapping them in the deprecated
    # torch.autograd.Variable is a no-op and needs an extra import.
    optimizer.zero_grad()
    outputs = model(features)
    loss = criterion(outputs, labels)
    loss.backward()
    optimizer.step()

    return loss.item()

In [26]:
def evaluate_model(model, features, labels):
    """Return classification accuracy of ``model`` on one batch.

    Args:
        model: module producing one score per class along dim 1; switched to eval mode.
        features: input tensor fed to ``model`` in a single batch.
        labels: 1-D tensor of target class indices, one per row of ``features``.

    Returns:
        float: fraction of rows whose arg-max score equals the label.
    """
    model.eval()

    # Inference only: skip building an autograd graph (replaces the deprecated
    # Variable wrapping and the `.data` access).
    with torch.no_grad():
        outputs = model(features)
    _, predicted = torch.max(outputs, 1)

    total = labels.size(0)
    correct = (predicted == labels).sum().item()
    return correct / total

In [27]:
# Seed torch so the classifier's weight initialisation is reproducible across runs
# (the train/val split below is already seeded with random_state=42).
SEED = 42
# Output classes of the classifier; CrossEntropyLoss requires every label in [0, NUM_CLASSES).
NUM_CLASSES = 16
LEARNING_RATE = 0.001

torch.manual_seed(SEED)
model = create_model(num_classes=NUM_CLASSES, num_features=X.shape[1])
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
criterion = torch.nn.CrossEntropyLoss()

In [31]:
from sklearn.model_selection import train_test_split
from torch.autograd import Variable

import copy

NUM_EPOCHS = 1000

# NOTE: re-running this cell keeps training the existing `model`; re-run the model-setup
# cell first for a fresh run.
X_train, X_val, Y_train, Y_val = train_test_split(X, Y, test_size=0.1, random_state=42)
best_val_loss = float('inf')
best_model = None
for epoch in range(NUM_EPOCHS):
    loss = train_model(model, optimizer, criterion, X_train, Y_train)
    model.eval()
    with torch.no_grad():
        val_outputs = model(X_val)
        val_loss = criterion(val_outputs, Y_val).item()
    if val_loss < best_val_loss:
        # Snapshot the weights. `best_model = model` would only alias the live model,
        # which keeps training, so the "best" model would silently become the last one.
        best_model = copy.deepcopy(model)
        best_val_loss = val_loss
    print(f'Epoch: {epoch+1}/{NUM_EPOCHS} | Training Loss: {loss:.4f} | Validation Loss: {val_loss:.4f}')


Epoch: 1/1000 | Training Loss: 0.7524 | Validation Loss: 3.4944
Epoch: 2/1000 | Training Loss: 0.7326 | Validation Loss: 3.4910
Epoch: 3/1000 | Training Loss: 0.7326 | Validation Loss: 3.4977
Epoch: 4/1000 | Training Loss: 0.7341 | Validation Loss: 3.4961
Epoch: 5/1000 | Training Loss: 0.7356 | Validation Loss: 3.4786
Epoch: 6/1000 | Training Loss: 0.7429 | Validation Loss: 3.5306
Epoch: 7/1000 | Training Loss: 0.7565 | Validation Loss: 3.5333
Epoch: 8/1000 | Training Loss: 0.7689 | Validation Loss: 3.5109
Epoch: 9/1000 | Training Loss: 0.7745 | Validation Loss: 3.5585
Epoch: 10/1000 | Training Loss: 0.7765 | Validation Loss: 3.4747
Epoch: 11/1000 | Training Loss: 0.7446 | Validation Loss: 3.4738
Epoch: 12/1000 | Training Loss: 0.7233 | Validation Loss: 3.5708
Epoch: 13/1000 | Training Loss: 0.7361 | Validation Loss: 3.5184
Epoch: 14/1000 | Training Loss: 0.7644 | Validation Loss: 3.5830
Epoch: 15/1000 | Training Loss: 0.7757 | Validation Loss: 3.5402
Epoch: 16/1000 | Training Loss: 0.

Epoch: 132/1000 | Training Loss: 0.8048 | Validation Loss: 3.7798
Epoch: 133/1000 | Training Loss: 0.7586 | Validation Loss: 3.7599
Epoch: 134/1000 | Training Loss: 0.7280 | Validation Loss: 3.9264
Epoch: 135/1000 | Training Loss: 0.7762 | Validation Loss: 3.7080
Epoch: 136/1000 | Training Loss: 0.7895 | Validation Loss: 3.7455
Epoch: 137/1000 | Training Loss: 0.7249 | Validation Loss: 3.8644
Epoch: 138/1000 | Training Loss: 0.7518 | Validation Loss: 3.6879
Epoch: 139/1000 | Training Loss: 0.7630 | Validation Loss: 3.8144
Epoch: 140/1000 | Training Loss: 0.7332 | Validation Loss: 3.8962
Epoch: 141/1000 | Training Loss: 0.7297 | Validation Loss: 3.7484
Epoch: 142/1000 | Training Loss: 0.7342 | Validation Loss: 3.7506
Epoch: 143/1000 | Training Loss: 0.7167 | Validation Loss: 3.8049
Epoch: 144/1000 | Training Loss: 0.7171 | Validation Loss: 3.7197
Epoch: 145/1000 | Training Loss: 0.7184 | Validation Loss: 3.7819
Epoch: 146/1000 | Training Loss: 0.7144 | Validation Loss: 3.8677
Epoch: 147

Epoch: 257/1000 | Training Loss: 0.7037 | Validation Loss: 3.9390
Epoch: 258/1000 | Training Loss: 0.6972 | Validation Loss: 4.0369
Epoch: 259/1000 | Training Loss: 0.6725 | Validation Loss: 4.0200
Epoch: 260/1000 | Training Loss: 0.6581 | Validation Loss: 3.9786
Epoch: 261/1000 | Training Loss: 0.6681 | Validation Loss: 4.1065
Epoch: 262/1000 | Training Loss: 0.7012 | Validation Loss: 3.9978
Epoch: 263/1000 | Training Loss: 0.6831 | Validation Loss: 3.9963
Epoch: 264/1000 | Training Loss: 0.6847 | Validation Loss: 4.0898
Epoch: 265/1000 | Training Loss: 0.6767 | Validation Loss: 4.0013
Epoch: 266/1000 | Training Loss: 0.6708 | Validation Loss: 4.0239
Epoch: 267/1000 | Training Loss: 0.6718 | Validation Loss: 3.9939
Epoch: 268/1000 | Training Loss: 0.6601 | Validation Loss: 4.0131
Epoch: 269/1000 | Training Loss: 0.6558 | Validation Loss: 4.1006
Epoch: 270/1000 | Training Loss: 0.6831 | Validation Loss: 4.0379
Epoch: 271/1000 | Training Loss: 0.6590 | Validation Loss: 4.0108
Epoch: 272

Epoch: 382/1000 | Training Loss: 0.6206 | Validation Loss: 4.2740
Epoch: 383/1000 | Training Loss: 0.6207 | Validation Loss: 4.2245
Epoch: 384/1000 | Training Loss: 0.6055 | Validation Loss: 4.2410
Epoch: 385/1000 | Training Loss: 0.6031 | Validation Loss: 4.3043
Epoch: 386/1000 | Training Loss: 0.6101 | Validation Loss: 4.2336
Epoch: 387/1000 | Training Loss: 0.6172 | Validation Loss: 4.2660
Epoch: 388/1000 | Training Loss: 0.6269 | Validation Loss: 4.3107
Epoch: 389/1000 | Training Loss: 0.6601 | Validation Loss: 4.2583
Epoch: 390/1000 | Training Loss: 0.6840 | Validation Loss: 4.3210
Epoch: 391/1000 | Training Loss: 0.6715 | Validation Loss: 4.2586
Epoch: 392/1000 | Training Loss: 0.6230 | Validation Loss: 4.2448
Epoch: 393/1000 | Training Loss: 0.6155 | Validation Loss: 4.3678
Epoch: 394/1000 | Training Loss: 0.6346 | Validation Loss: 4.2790
Epoch: 395/1000 | Training Loss: 0.6550 | Validation Loss: 4.2745
Epoch: 396/1000 | Training Loss: 0.6553 | Validation Loss: 4.2720
Epoch: 397

Epoch: 507/1000 | Training Loss: 0.6401 | Validation Loss: 4.4500
Epoch: 508/1000 | Training Loss: 0.6596 | Validation Loss: 4.5919
Epoch: 509/1000 | Training Loss: 0.6836 | Validation Loss: 4.6079
Epoch: 510/1000 | Training Loss: 0.6466 | Validation Loss: 4.4745
Epoch: 511/1000 | Training Loss: 0.6427 | Validation Loss: 4.5139
Epoch: 512/1000 | Training Loss: 0.6397 | Validation Loss: 4.5011
Epoch: 513/1000 | Training Loss: 0.6161 | Validation Loss: 4.4907
Epoch: 514/1000 | Training Loss: 0.6397 | Validation Loss: 4.5498
Epoch: 515/1000 | Training Loss: 0.6440 | Validation Loss: 4.5005
Epoch: 516/1000 | Training Loss: 0.6217 | Validation Loss: 4.4926
Epoch: 517/1000 | Training Loss: 0.6261 | Validation Loss: 4.4712
Epoch: 518/1000 | Training Loss: 0.6151 | Validation Loss: 4.4526
Epoch: 519/1000 | Training Loss: 0.6093 | Validation Loss: 4.4691
Epoch: 520/1000 | Training Loss: 0.6158 | Validation Loss: 4.5600
Epoch: 521/1000 | Training Loss: 0.6182 | Validation Loss: 4.5075
Epoch: 522

Epoch: 636/1000 | Training Loss: 0.5766 | Validation Loss: 4.6743
Epoch: 637/1000 | Training Loss: 0.5773 | Validation Loss: 4.6765
Epoch: 638/1000 | Training Loss: 0.5753 | Validation Loss: 4.8044
Epoch: 639/1000 | Training Loss: 0.5935 | Validation Loss: 4.7542
Epoch: 640/1000 | Training Loss: 0.6137 | Validation Loss: 4.8018
Epoch: 641/1000 | Training Loss: 0.6031 | Validation Loss: 4.7521
Epoch: 642/1000 | Training Loss: 0.5780 | Validation Loss: 4.7101
Epoch: 643/1000 | Training Loss: 0.5664 | Validation Loss: 4.8194
Epoch: 644/1000 | Training Loss: 0.5695 | Validation Loss: 4.8028
Epoch: 645/1000 | Training Loss: 0.5932 | Validation Loss: 4.7573
Epoch: 646/1000 | Training Loss: 0.6065 | Validation Loss: 4.7476
Epoch: 647/1000 | Training Loss: 0.5958 | Validation Loss: 4.8029
Epoch: 648/1000 | Training Loss: 0.5891 | Validation Loss: 4.6774
Epoch: 649/1000 | Training Loss: 0.5619 | Validation Loss: 4.7465
Epoch: 650/1000 | Training Loss: 0.5530 | Validation Loss: 4.8192
Epoch: 651

Epoch: 764/1000 | Training Loss: 0.5329 | Validation Loss: 4.9600
Epoch: 765/1000 | Training Loss: 0.5495 | Validation Loss: 4.9845
Epoch: 766/1000 | Training Loss: 0.5332 | Validation Loss: 4.9656
Epoch: 767/1000 | Training Loss: 0.5354 | Validation Loss: 4.9590
Epoch: 768/1000 | Training Loss: 0.5222 | Validation Loss: 5.0228
Epoch: 769/1000 | Training Loss: 0.5361 | Validation Loss: 4.9672
Epoch: 770/1000 | Training Loss: 0.5459 | Validation Loss: 5.0572
Epoch: 771/1000 | Training Loss: 0.5596 | Validation Loss: 5.0196
Epoch: 772/1000 | Training Loss: 0.5751 | Validation Loss: 5.0511
Epoch: 773/1000 | Training Loss: 0.5955 | Validation Loss: 5.1002
Epoch: 774/1000 | Training Loss: 0.5728 | Validation Loss: 5.0337
Epoch: 775/1000 | Training Loss: 0.5507 | Validation Loss: 5.0062
Epoch: 776/1000 | Training Loss: 0.5342 | Validation Loss: 5.0245
Epoch: 777/1000 | Training Loss: 0.5489 | Validation Loss: 4.9474
Epoch: 778/1000 | Training Loss: 0.5402 | Validation Loss: 4.9682
Epoch: 779

Epoch: 889/1000 | Training Loss: 0.5163 | Validation Loss: 5.2096
Epoch: 890/1000 | Training Loss: 0.5172 | Validation Loss: 5.2687
Epoch: 891/1000 | Training Loss: 0.5276 | Validation Loss: 5.2651
Epoch: 892/1000 | Training Loss: 0.5292 | Validation Loss: 5.2510
Epoch: 893/1000 | Training Loss: 0.5332 | Validation Loss: 5.1966
Epoch: 894/1000 | Training Loss: 0.5217 | Validation Loss: 5.2667
Epoch: 895/1000 | Training Loss: 0.5126 | Validation Loss: 5.2496
Epoch: 896/1000 | Training Loss: 0.5111 | Validation Loss: 5.2605
Epoch: 897/1000 | Training Loss: 0.5230 | Validation Loss: 5.3157
Epoch: 898/1000 | Training Loss: 0.5186 | Validation Loss: 5.1698
Epoch: 899/1000 | Training Loss: 0.5167 | Validation Loss: 5.3261
Epoch: 900/1000 | Training Loss: 0.5116 | Validation Loss: 5.2316
Epoch: 901/1000 | Training Loss: 0.5004 | Validation Loss: 5.2289
Epoch: 902/1000 | Training Loss: 0.4950 | Validation Loss: 5.3421
Epoch: 903/1000 | Training Loss: 0.5004 | Validation Loss: 5.2466
Epoch: 904

In [29]:
# Lowest validation loss recorded by the training loop.
# NOTE(review): the stored output was produced at execution count 29, before the training
# cell's count 31, so it does not belong to the logged run (whose validation losses never
# drop below ~3.47). Restart & Run All to refresh it.
best_val_loss

2.0436971187591553

In [30]:
# Bundle the weights of `best_model` and the optimizer state into one checkpoint file.
# NOTE(review): the optimizer state is whatever it was after the last training step, which
# need not match the epoch best_model was taken from.
checkpoint = {
    'model_state_dict': best_model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
}
torch.save(checkpoint, "./sgg_ckpt.pth")

In [8]:
from SceneGraph import SceneGraph
import torch

class SGGInference:
    """Load a trained SceneGraph checkpoint and predict class indices for input features."""

    def __init__(self, model_path, num_classes=16, num_features=2318, map_location="cpu"):
        """Build the model and load weights from ``model_path``.

        Args:
            model_path: path to a checkpoint dict containing ``'model_state_dict'``.
            num_classes: output classes the checkpoint was trained with.
            num_features: input width the checkpoint was trained with (must equal the
                ``X.shape[1]`` used at training time).
            map_location: device the checkpoint tensors are loaded onto; ``"cpu"`` lets a
                GPU-saved checkpoint load on a CPU-only machine (the model is built on CPU).
        """
        self.model = SceneGraph(num_classes, num_features)
        # NOTE: torch.load unpickles the file — only load checkpoints you trust.
        checkpoint = torch.load(model_path, map_location=map_location)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        # Inference mode for any layers that behave differently in training.
        self.model.eval()

    def infer(self, X):
        """Return the arg-max class index along dim 1 of the model output.

        Args:
            X: batch of feature rows.

        Returns:
            int for a single-row batch (as before); list of ints for a larger batch.
        """
        # Single forward pass, no autograd graph.
        with torch.no_grad():
            logits = self.model(X)
        predictions = torch.argmax(logits, dim=1)
        if predictions.numel() == 1:
            return predictions.item()
        return predictions.tolist()