In [1]:
from datasets.ae_feat_dataset import FeatDataset
from models.lstm import BiLSTM
import torch
from torch.nn import functional as F

In [2]:
train_dataset = FeatDataset('dataset/train')
val_dataset = FeatDataset('dataset/test')

In [9]:
batch_size = 512
epochs = 200
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=True)

In [4]:
device = torch.device('cpu')
model = BiLSTM()
model.to(device)
model.train()

BiLSTM(
  (lstm): LSTM(324, 64, batch_first=True, bidirectional=True)
  (linear): Linear(in_features=256, out_features=64, bias=True)
  (relu): ReLU()
  (dropout): Dropout(p=0.1, inplace=False)
  (out): Linear(in_features=64, out_features=128, bias=True)
)

In [5]:

margin = 1.
triplet_loss = torch.nn.TripletMarginWithDistanceLoss(distance_function=lambda x, y: 1.0 - F.cosine_similarity(x, y))
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 8, gamma=0.1, last_epoch=-1)

In [10]:
start_epoch = 0
for epoch in range(start_epoch, epochs):
    model.train()
    losses = []
    iter = 0
    for batch_idx, (anchor, positive, negative) in enumerate(train_loader):

        anchor = anchor.to(device)
        positive = positive.to(device)
        negative = negative.to(device)
        optimizer.zero_grad()

        anchor_out = model(anchor.float())
        positive_out = model(positive.float())
        negative_out = model(negative.float())


        output = triplet_loss(anchor_out, positive_out, negative_out)
        output.backward()
        optimizer.step()


        print(f'Epoch: [%d]/[%d] Training, %.2f%%,{iter}/{len(train_dataset)}, Loss={output.item()}' % (epoch, epochs, iter*100/len(train_dataset)))
        iter += len(anchor)
        losses.append(output.item())
    print(sum(losses)/len(losses))




Epoch: [0]/[200] Training, 0.00%,0/15849, Loss=0.3824446201324463
Epoch: [0]/[200] Training, 3.23%,512/15849, Loss=0.37649768590927124
Epoch: [0]/[200] Training, 6.46%,1024/15849, Loss=0.3879859149456024
Epoch: [0]/[200] Training, 9.69%,1536/15849, Loss=0.41673165559768677
Epoch: [0]/[200] Training, 12.92%,2048/15849, Loss=0.3725932836532593
Epoch: [0]/[200] Training, 16.15%,2560/15849, Loss=0.387528657913208
Epoch: [0]/[200] Training, 19.38%,3072/15849, Loss=0.39086002111434937
Epoch: [0]/[200] Training, 22.61%,3584/15849, Loss=0.38647735118865967
Epoch: [0]/[200] Training, 25.84%,4096/15849, Loss=0.36221566796302795
Epoch: [0]/[200] Training, 29.07%,4608/15849, Loss=0.401569664478302
Epoch: [0]/[200] Training, 32.30%,5120/15849, Loss=0.39867106080055237
Epoch: [0]/[200] Training, 35.54%,5632/15849, Loss=0.37596726417541504
Epoch: [0]/[200] Training, 38.77%,6144/15849, Loss=0.39209672808647156
Epoch: [0]/[200] Training, 42.00%,6656/15849, Loss=0.40474680066108704
Epoch: [0]/[200] Trai


KeyboardInterrupt



In [14]:
def calc_euclidean(feat1, feat2):
    feat1 = feat1.ravel()
    feat2 = feat2.ravel()
    return torch.dot(feat1, feat2) / (torch.linalg.norm(feat1) * torch.linalg.norm(feat2))
ok = 0
fales = 0
for i in range(200):
    a, p, n = val_dataset.__getitem__(i)
    a = torch.from_numpy(a).float()
    p = torch.from_numpy(p).float()
    n = torch.from_numpy(n).float()

    feat_a = model(a.unsqueeze(0))
    feat_p = model(p.unsqueeze(0))
    feat_n = model(n.unsqueeze(0))

    d_ap = calc_euclidean(feat_a, feat_p).detach().cpu().numpy()
    d_an = calc_euclidean(feat_a, feat_n).detach().cpu().numpy()
    print(d_ap, d_an)
    if d_ap>0.5:
        ok += 1
    else:
        fales += 1
    if d_an<=0.5:
        ok += 1
    else:
        fales += 1

print(ok, fales)


0.93492377 -0.24233335
0.9463702 -0.21094653
0.98746496 -0.085279986
0.8990053 -0.14777455
0.974549 -0.16407453
0.91302896 -0.28488055
0.9742724 -0.123406775
0.8666797 -0.054263745
0.9800681 -0.21377677
0.9667697 -0.23248772
0.723328 -0.37068725
0.90487546 0.9922157
0.9006486 0.03421862
0.97548574 0.9730354
0.96246886 -0.11990808
0.9172875 0.9979038
0.98029006 -0.3672654
0.943522 -0.27532157
0.71167856 -0.17687374
0.9295279 0.76083004
0.9956552 -0.09590244
0.9454024 -0.6619248
0.98020005 -0.10497154
0.9353707 -0.100796945
0.93954754 0.90227723
0.915203 0.984315
0.9427173 0.93906873
0.92574644 0.97098535
0.970981 0.90607727
0.93067354 -0.67093337
0.9843961 -0.06573344
0.9984557 -0.65910107
0.99875265 0.32489637
0.5831846 0.9005014
0.9642543 0.2609868
0.97541034 -0.17695755
0.74179155 0.5975812
0.9437297 -0.5062643
0.983991 0.9244435
0.8639407 -0.06669008
0.97711146 -0.62418044
0.99152243 0.9496718
0.9909818 -0.6139446
0.91971713 -0.180524
0.9169249 -0.061829407
0.69248176 0.882265
0.620