In [43]:
from cls_models import ClsUnseenTrain
from generate import load_seen_att
import numpy as np
import torch
from torch.utils.data import TensorDataset, DataLoader, Dataset
import torch.optim as optim
import torch.nn as nn
from mmdetection.splits import get_seen_class_ids


In [44]:
%psource ClsUnseenTrain.forward

In [45]:
class dotdict(dict):
    """dot.notation access to dictionary attributes"""
    __getattr__ = dict.get
    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__

In [46]:
opt = dotdict({
    'dataset':'coco',
    'classes_split': '65_15',
    'class_embedding': 'MSCOCO/fasttext.npy',
    'dataroot':'../../data/coco',
    'trainsplit': 'train_0.6_0.3',
    
})

In [47]:
seen_att, att_labels = load_seen_att(opt)
classid_tolabels = {l:i for i, l in enumerate(att_labels.data.numpy())}

In [48]:
unseen_classifier = ClsUnseenTrain(seen_att).cuda()


__init__ torch.Size([66, 300])


In [9]:
seen_features = np.load(f"{opt.dataroot}/{opt.trainsplit}_feats.npy")
seen_labels = np.load(f"{opt.dataroot}/{opt.trainsplit}_labels.npy")


In [55]:
inds = np.random.permutation(np.arange(len(seen_labels)))
total_train_examples = int (0.8 * len(seen_labels))
train_inds = inds[:total_train_examples]
test_inds = inds[total_train_examples:]

In [56]:
len(test_inds)+len(train_inds), len(seen_labels)

(8846001, 8846001)

In [57]:
train_feats = seen_features[train_inds]
train_labels = seen_labels[train_inds]
test_feats = seen_features[test_inds]
test_labels = seen_labels[test_inds]

In [37]:
# bg_inds = np.where(seen_labels==0)
# fg_inds = np.where(seen_labels>0)

In [58]:
class Featuresdataset(Dataset):
     
    def __init__(self, features, labels, classid_tolabels):
        self.classid_tolabels = classid_tolabels
        self.features = features
        self.labels = labels
        

    def __getitem__(self, idx):
        batch_feature = self.features[idx]
        batch_label = self.labels[idx]
#         import pdb; pdb.set_trace()
        
        if self.classid_tolabels is not None:
            batch_label = self.classid_tolabels[batch_label]
        return batch_feature, batch_label

    def __len__(self):
        return len(self.labels)

In [50]:
seen_labels.shape

(8846001,)

In [109]:

dataset_train = Featuresdataset(train_feats, train_labels, classid_tolabels)
dataloader_train = DataLoader(dataset_train, batch_size=512, shuffle=True) 
dataset_test = Featuresdataset(test_feats, test_labels, classid_tolabels)
dataloader_test = DataLoader(dataset_test, batch_size=1024, shuffle=True) 

In [110]:
from torch.optim.lr_scheduler import StepLR

criterion = nn.CrossEntropyLoss()

# criterion = nn.NLLLoss()

optimizer = optim.SGD(unseen_classifier.parameters(), lr=1, momentum=0.9)
scheduler = StepLR(optimizer, step_size=30, gamma=0.1)


In [105]:
min_val_loss = float("inf")

In [106]:
path = 'MSCOCO/unseen_Classifier.pth'

In [111]:
def val():
    running_loss = 0.0
    global min_val_loss
    unseen_classifier.eval()
    for i, (inputs, labels) in enumerate(dataloader_test, 0):
        inputs = inputs.cuda()
        labels = labels.cuda()
        

        outputs = unseen_classifier(inputs)
        loss = criterion(outputs, labels)

        running_loss += loss.item()
        if i % 200 == 199:
            print(f'Test Loss {epoch + 1}, [{i + 1} / {len(dataloader_test)}], {(running_loss / i) :0.4f}')
    if (running_loss / i) < min_val_loss:
        min_val_loss = running_loss / i
        state_dict = unseen_classifier.state_dict()   
        torch.save(state_dict, path)
        print(f'saved {min_val_loss :0.4f}')

In [112]:
for epoch in range(100):  # loop over the dataset multiple times
    unseen_classifier.train()
    running_loss = 0.0
    for i, (inputs, labels) in enumerate(dataloader_train, 0):
        inputs = inputs.cuda()
        labels = labels.cuda()
        
        optimizer.zero_grad()

        outputs = unseen_classifier(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        if i % 1000 == 999:    # print every 2000 mini-batches
            print(f'Train Loss {epoch + 1}, [{i + 1} / {len(dataloader_train)}], {(running_loss / i) :0.4f}')
    val()
    scheduler.step()
    
print('Finished Training')

Train Loss 1, [1000 / 13822], 0.1713
Train Loss 1, [2000 / 13822], 0.1707
Train Loss 1, [3000 / 13822], 0.1710
Train Loss 1, [4000 / 13822], 0.1707
Train Loss 1, [5000 / 13822], 0.1703
Train Loss 1, [6000 / 13822], 0.1704
Train Loss 1, [7000 / 13822], 0.1704
Train Loss 1, [8000 / 13822], 0.1704
Train Loss 1, [9000 / 13822], 0.1704
Train Loss 1, [10000 / 13822], 0.1702
Train Loss 1, [11000 / 13822], 0.1702
Train Loss 1, [12000 / 13822], 0.1703
Train Loss 1, [13000 / 13822], 0.1702
Test Loss 1, [200 / 1728], 0.1735
Test Loss 1, [400 / 1728], 0.1735
Test Loss 1, [600 / 1728], 0.1735
Test Loss 1, [800 / 1728], 0.1738
Test Loss 1, [1000 / 1728], 0.1738
Test Loss 1, [1200 / 1728], 0.1741
Test Loss 1, [1400 / 1728], 0.1742
Test Loss 1, [1600 / 1728], 0.1742
saved 0.1741
Train Loss 2, [1000 / 13822], 0.1682
Train Loss 2, [2000 / 13822], 0.1688
Train Loss 2, [3000 / 13822], 0.1692
Train Loss 2, [4000 / 13822], 0.1688
Train Loss 2, [5000 / 13822], 0.1691
Train Loss 2, [6000 / 13822], 0.1692
Trai

Test Loss 11, [400 / 1728], 0.1742
Test Loss 11, [600 / 1728], 0.1742
Test Loss 11, [800 / 1728], 0.1742
Test Loss 11, [1000 / 1728], 0.1740
Test Loss 11, [1200 / 1728], 0.1740
Test Loss 11, [1400 / 1728], 0.1741
Test Loss 11, [1600 / 1728], 0.1743
Train Loss 12, [1000 / 13822], 0.1668
Train Loss 12, [2000 / 13822], 0.1668
Train Loss 12, [3000 / 13822], 0.1666
Train Loss 12, [4000 / 13822], 0.1672
Train Loss 12, [5000 / 13822], 0.1678
Train Loss 12, [6000 / 13822], 0.1680
Train Loss 12, [7000 / 13822], 0.1680
Train Loss 12, [8000 / 13822], 0.1681
Train Loss 12, [9000 / 13822], 0.1683
Train Loss 12, [10000 / 13822], 0.1684
Train Loss 12, [11000 / 13822], 0.1685
Train Loss 12, [12000 / 13822], 0.1684
Train Loss 12, [13000 / 13822], 0.1685
Test Loss 12, [200 / 1728], 0.1711
Test Loss 12, [400 / 1728], 0.1707
Test Loss 12, [600 / 1728], 0.1711
Test Loss 12, [800 / 1728], 0.1711
Test Loss 12, [1000 / 1728], 0.1711
Test Loss 12, [1200 / 1728], 0.1712
Test Loss 12, [1400 / 1728], 0.1715
Test 

Train Loss 22, [4000 / 13822], 0.1674
Train Loss 22, [5000 / 13822], 0.1673
Train Loss 22, [6000 / 13822], 0.1674
Train Loss 22, [7000 / 13822], 0.1676
Train Loss 22, [8000 / 13822], 0.1677
Train Loss 22, [9000 / 13822], 0.1679
Train Loss 22, [10000 / 13822], 0.1679
Train Loss 22, [11000 / 13822], 0.1680
Train Loss 22, [12000 / 13822], 0.1681
Train Loss 22, [13000 / 13822], 0.1682
Test Loss 22, [200 / 1728], 0.1749
Test Loss 22, [400 / 1728], 0.1743
Test Loss 22, [600 / 1728], 0.1750
Test Loss 22, [800 / 1728], 0.1745
Test Loss 22, [1000 / 1728], 0.1740
Test Loss 22, [1200 / 1728], 0.1743
Test Loss 22, [1400 / 1728], 0.1743
Test Loss 22, [1600 / 1728], 0.1743
Train Loss 23, [1000 / 13822], 0.1675
Train Loss 23, [2000 / 13822], 0.1675
Train Loss 23, [3000 / 13822], 0.1676
Train Loss 23, [4000 / 13822], 0.1678
Train Loss 23, [5000 / 13822], 0.1679
Train Loss 23, [6000 / 13822], 0.1678
Train Loss 23, [7000 / 13822], 0.1678
Train Loss 23, [8000 / 13822], 0.1678
Train Loss 23, [9000 / 13822

Test Loss 32, [200 / 1728], 0.1707
Test Loss 32, [400 / 1728], 0.1697
Test Loss 32, [600 / 1728], 0.1688
Test Loss 32, [800 / 1728], 0.1684
Test Loss 32, [1000 / 1728], 0.1679
Test Loss 32, [1200 / 1728], 0.1677
Test Loss 32, [1400 / 1728], 0.1677
Test Loss 32, [1600 / 1728], 0.1679
saved 0.1677
Train Loss 33, [1000 / 13822], 0.1616
Train Loss 33, [2000 / 13822], 0.1605
Train Loss 33, [3000 / 13822], 0.1615
Train Loss 33, [4000 / 13822], 0.1616
Train Loss 33, [5000 / 13822], 0.1616
Train Loss 33, [6000 / 13822], 0.1615
Train Loss 33, [7000 / 13822], 0.1613
Train Loss 33, [8000 / 13822], 0.1614
Train Loss 33, [9000 / 13822], 0.1615
Train Loss 33, [10000 / 13822], 0.1614
Train Loss 33, [11000 / 13822], 0.1614
Train Loss 33, [12000 / 13822], 0.1615
Train Loss 33, [13000 / 13822], 0.1615
Test Loss 33, [200 / 1728], 0.1710
Test Loss 33, [400 / 1728], 0.1702
Test Loss 33, [600 / 1728], 0.1687
Test Loss 33, [800 / 1728], 0.1683
Test Loss 33, [1000 / 1728], 0.1678
Test Loss 33, [1200 / 1728], 

Train Loss 43, [2000 / 13822], 0.1613
Train Loss 43, [3000 / 13822], 0.1612
Train Loss 43, [4000 / 13822], 0.1611
Train Loss 43, [5000 / 13822], 0.1613
Train Loss 43, [6000 / 13822], 0.1614
Train Loss 43, [7000 / 13822], 0.1614
Train Loss 43, [8000 / 13822], 0.1614
Train Loss 43, [9000 / 13822], 0.1613
Train Loss 43, [10000 / 13822], 0.1614
Train Loss 43, [11000 / 13822], 0.1613
Train Loss 43, [12000 / 13822], 0.1614
Train Loss 43, [13000 / 13822], 0.1613
Test Loss 43, [200 / 1728], 0.1683
Test Loss 43, [400 / 1728], 0.1675
Test Loss 43, [600 / 1728], 0.1674
Test Loss 43, [800 / 1728], 0.1675
Test Loss 43, [1000 / 1728], 0.1676
Test Loss 43, [1200 / 1728], 0.1674
Test Loss 43, [1400 / 1728], 0.1676
Test Loss 43, [1600 / 1728], 0.1676
Train Loss 44, [1000 / 13822], 0.1613
Train Loss 44, [2000 / 13822], 0.1607
Train Loss 44, [3000 / 13822], 0.1608
Train Loss 44, [4000 / 13822], 0.1610
Train Loss 44, [5000 / 13822], 0.1611
Train Loss 44, [6000 / 13822], 0.1612
Train Loss 44, [7000 / 13822

Train Loss 53, [12000 / 13822], 0.1612
Train Loss 53, [13000 / 13822], 0.1612
Test Loss 53, [200 / 1728], 0.1650
Test Loss 53, [400 / 1728], 0.1661
Test Loss 53, [600 / 1728], 0.1671
Test Loss 53, [800 / 1728], 0.1673
Test Loss 53, [1000 / 1728], 0.1678
Test Loss 53, [1200 / 1728], 0.1678
Test Loss 53, [1400 / 1728], 0.1675
Test Loss 53, [1600 / 1728], 0.1674
saved 0.1675
Train Loss 54, [1000 / 13822], 0.1600
Train Loss 54, [2000 / 13822], 0.1605
Train Loss 54, [3000 / 13822], 0.1608
Train Loss 54, [4000 / 13822], 0.1606
Train Loss 54, [5000 / 13822], 0.1607
Train Loss 54, [6000 / 13822], 0.1606
Train Loss 54, [7000 / 13822], 0.1608
Train Loss 54, [8000 / 13822], 0.1609
Train Loss 54, [9000 / 13822], 0.1610
Train Loss 54, [10000 / 13822], 0.1612
Train Loss 54, [11000 / 13822], 0.1611
Train Loss 54, [12000 / 13822], 0.1611
Train Loss 54, [13000 / 13822], 0.1612
Test Loss 54, [200 / 1728], 0.1704
Test Loss 54, [400 / 1728], 0.1680
Test Loss 54, [600 / 1728], 0.1679
Test Loss 54, [800 / 1

Test Loss 63, [1600 / 1728], 0.1672
saved 0.1671
Train Loss 64, [1000 / 13822], 0.1604
Train Loss 64, [2000 / 13822], 0.1605
Train Loss 64, [3000 / 13822], 0.1605
Train Loss 64, [4000 / 13822], 0.1605
Train Loss 64, [5000 / 13822], 0.1604
Train Loss 64, [6000 / 13822], 0.1606
Train Loss 64, [7000 / 13822], 0.1607
Train Loss 64, [8000 / 13822], 0.1606
Train Loss 64, [9000 / 13822], 0.1606
Train Loss 64, [10000 / 13822], 0.1606
Train Loss 64, [11000 / 13822], 0.1606
Train Loss 64, [12000 / 13822], 0.1606
Train Loss 64, [13000 / 13822], 0.1607
Test Loss 64, [200 / 1728], 0.1689
Test Loss 64, [400 / 1728], 0.1677
Test Loss 64, [600 / 1728], 0.1679
Test Loss 64, [800 / 1728], 0.1679
Test Loss 64, [1000 / 1728], 0.1677
Test Loss 64, [1200 / 1728], 0.1673
Test Loss 64, [1400 / 1728], 0.1669
Test Loss 64, [1600 / 1728], 0.1671
Train Loss 65, [1000 / 13822], 0.1610
Train Loss 65, [2000 / 13822], 0.1605
Train Loss 65, [3000 / 13822], 0.1605
Train Loss 65, [4000 / 13822], 0.1605
Train Loss 65, [5

Train Loss 74, [10000 / 13822], 0.1605
Train Loss 74, [11000 / 13822], 0.1604
Train Loss 74, [12000 / 13822], 0.1604
Train Loss 74, [13000 / 13822], 0.1605
Test Loss 74, [200 / 1728], 0.1682
Test Loss 74, [400 / 1728], 0.1687
Test Loss 74, [600 / 1728], 0.1678
Test Loss 74, [800 / 1728], 0.1672
Test Loss 74, [1000 / 1728], 0.1670
Test Loss 74, [1200 / 1728], 0.1667
Test Loss 74, [1400 / 1728], 0.1671
Test Loss 74, [1600 / 1728], 0.1673
Train Loss 75, [1000 / 13822], 0.1621
Train Loss 75, [2000 / 13822], 0.1610
Train Loss 75, [3000 / 13822], 0.1611
Train Loss 75, [4000 / 13822], 0.1610
Train Loss 75, [5000 / 13822], 0.1611
Train Loss 75, [6000 / 13822], 0.1609
Train Loss 75, [7000 / 13822], 0.1609
Train Loss 75, [8000 / 13822], 0.1607
Train Loss 75, [9000 / 13822], 0.1607
Train Loss 75, [10000 / 13822], 0.1607
Train Loss 75, [11000 / 13822], 0.1607
Train Loss 75, [12000 / 13822], 0.1608
Train Loss 75, [13000 / 13822], 0.1607
Test Loss 75, [200 / 1728], 0.1690
Test Loss 75, [400 / 1728],

Test Loss 84, [1400 / 1728], 0.1674
Test Loss 84, [1600 / 1728], 0.1672
Train Loss 85, [1000 / 13822], 0.1597
Train Loss 85, [2000 / 13822], 0.1593
Train Loss 85, [3000 / 13822], 0.1593
Train Loss 85, [4000 / 13822], 0.1595
Train Loss 85, [5000 / 13822], 0.1599
Train Loss 85, [6000 / 13822], 0.1600
Train Loss 85, [7000 / 13822], 0.1604
Train Loss 85, [8000 / 13822], 0.1604
Train Loss 85, [9000 / 13822], 0.1606
Train Loss 85, [10000 / 13822], 0.1606
Train Loss 85, [11000 / 13822], 0.1606
Train Loss 85, [12000 / 13822], 0.1605
Train Loss 85, [13000 / 13822], 0.1606
Test Loss 85, [200 / 1728], 0.1671
Test Loss 85, [400 / 1728], 0.1673
Test Loss 85, [600 / 1728], 0.1667
Test Loss 85, [800 / 1728], 0.1667
Test Loss 85, [1000 / 1728], 0.1666
Test Loss 85, [1200 / 1728], 0.1667
Test Loss 85, [1400 / 1728], 0.1671
Test Loss 85, [1600 / 1728], 0.1671
Train Loss 86, [1000 / 13822], 0.1600
Train Loss 86, [2000 / 13822], 0.1606
Train Loss 86, [3000 / 13822], 0.1604
Train Loss 86, [4000 / 13822], 0

Train Loss 95, [9000 / 13822], 0.1605
Train Loss 95, [10000 / 13822], 0.1605
Train Loss 95, [11000 / 13822], 0.1605
Train Loss 95, [12000 / 13822], 0.1606
Train Loss 95, [13000 / 13822], 0.1606
Test Loss 95, [200 / 1728], 0.1665
Test Loss 95, [400 / 1728], 0.1666
Test Loss 95, [600 / 1728], 0.1665
Test Loss 95, [800 / 1728], 0.1660
Test Loss 95, [1000 / 1728], 0.1662
Test Loss 95, [1200 / 1728], 0.1665
Test Loss 95, [1400 / 1728], 0.1662
Test Loss 95, [1600 / 1728], 0.1670
saved 0.1671
Train Loss 96, [1000 / 13822], 0.1610
Train Loss 96, [2000 / 13822], 0.1607
Train Loss 96, [3000 / 13822], 0.1607
Train Loss 96, [4000 / 13822], 0.1609
Train Loss 96, [5000 / 13822], 0.1609
Train Loss 96, [6000 / 13822], 0.1607
Train Loss 96, [7000 / 13822], 0.1604
Train Loss 96, [8000 / 13822], 0.1602
Train Loss 96, [9000 / 13822], 0.1604
Train Loss 96, [10000 / 13822], 0.1603
Train Loss 96, [11000 / 13822], 0.1604
Train Loss 96, [12000 / 13822], 0.1604
Train Loss 96, [13000 / 13822], 0.1605
Test Loss 9