In [1]:
#!/usr/bin/env python

import torch
from torch import nn
from torch import optim
from torch import Tensor
from torch.nn import functional as F
import dlc_practical_prologue as prologue
import matplotlib.pyplot as plt
%matplotlib notebook

In [2]:
#Data generation
N=10**3
train_input,train_target,train_classes,test_input,test_target,test_classes=prologue.generate_pair_sets(N)
train_target=train_target.long()#.float for MSELoss, .long for CrossEntropy
train_input=train_input.float()
train_classes=train_classes.long()

In [3]:
#Base functions adapted from the practicals
def train_model(model, train_input, train_target,train_classes, mini_batch_size, crit=nn.MSELoss, eta = 1e-3, nb_epochs = 200,print_=False):
    criterion = crit()
    optimizer = optim.SGD(model.parameters(), lr = eta)
    for e in range(nb_epochs):
        acc_loss = 0
        acc_loss1 = 0
        acc_loss2 = 0
        acc_loss3 = 0

        for b in range(0, train_input.size(0), mini_batch_size):
            output,aux_output = model(train_input.narrow(0, b, mini_batch_size))
            if crit==nn.MSELoss:
                loss1 = criterion(output[:,1], train_target.narrow(0, b, mini_batch_size))
                #print(torch.argmax(aux_output[:,0:9],dim=1))
                #print(train_classes[:,0].narrow(0, b, mini_batch_size))
                loss2 = criterion(torch.argmax(aux_output[:,0:9],dim=1), train_classes[:,0].narrow(0, b, mini_batch_size))
                loss3 = criterion(torch.argmax(aux_output[:,10:19],dim=1), train_classes[:,1].narrow(0, b, mini_batch_size))
                loss = loss1 + loss2 + loss3
                print('|| loss1 req grad =', loss1.requires_grad, '|| loss2 req grad =',loss2.requires_grad,'|| loss3 req grad =', loss3.requires_grad)
            elif crit==nn.CrossEntropyLoss:
                loss1 = criterion(output, train_target.narrow(0, b, mini_batch_size))
                #print(torch.argmax(aux_output[:,0:9],dim=1))
                #print(train_classes[:,0].narrow(0, b, mini_batch_size))
                loss2 = criterion(aux_output[:,:10], train_classes[:,0].narrow(0, b, mini_batch_size))
                loss3 = criterion(aux_output[:,10:], train_classes[:,1].narrow(0, b, mini_batch_size))
                loss = loss1 + loss2 + loss3
                #print(loss1, loss2.requires_grad, loss3.requires_grad)
            else:
                print("Loss not implemented")
            acc_loss = acc_loss + loss.item()
            acc_loss1 = acc_loss1 + loss1.item()
            acc_loss2 = acc_loss2 + loss2.item()
            acc_loss3 = acc_loss3 + loss3.item()
            model.zero_grad()
            loss.backward()
            optimizer.step()
            if False:
                with torch.no_grad():
                    for p in model.parameters():
                        p -= eta * p.grad

        print(e, 'tot loss', acc_loss, 'loss1', acc_loss1, 'loss2', acc_loss2, 'loss3', acc_loss3)
            
def compute_nb_errors(model, input, target, mini_batch_size=100):
    nb_errors = 0

    for b in range(0, input.size(0), mini_batch_size):
        output , aux_output = model(input.narrow(0, b, mini_batch_size))
        _, predicted_classes = output.max(1)
        for k in range(mini_batch_size):
            if target[b + k]!=predicted_classes[k]:
                nb_errors = nb_errors + 1

    return nb_errors

def run_many_times(model,crit=nn.MSELoss,mini_batch_size=100,n=10,print_=False):
    average_error=0
    for i in range(n):
        m=model()
        train_model(m, train_input, train_target,train_classes,mini_batch_size,crit=crit)
        nb_test_errors = compute_nb_errors(m, test_input, test_target, mini_batch_size)
        print('test error Net {:0.2f}% {:d}/{:d}'.format((100 * nb_test_errors) / test_input.size(0),
                                                      nb_test_errors, test_input.size(0)))
        average_error+=(100 * nb_test_errors) / test_input.size(0)
    print("Average error: "+str(average_error/n))

In [4]:
#Is it better to use groups or not?
#Takes about 2 hours to run
#about 22.5% error average without groups if we exclude outliers that get stuck and don't move
#about 21.5% error average with groups if we exclude outliers that get stuck and don't move
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(2, 32, kernel_size=3)
        #self.conv2 = nn.Conv2d(32, 64, kernel_size=2)
        self.fc1 = nn.Linear(512, 20)
        self.fc2 = nn.Linear(20, 2)
        self.aux_linear = nn.Linear(20, 20)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), kernel_size=3, stride=3))
        #x = F.relu(F.max_pool2d(self.conv2(x), kernel_size=2, stride=2))
        aux_output = F.softmax(self.fc1(x.view(-1, 512)), dim=1)
        x = F.relu(self.fc1(x.view(-1, 512)))
        output = F.softmax(self.fc2(x), dim=1)
        aux_output = F.softmax(self.aux_linear(x), dim=1)
        #print(x)
        return output, aux_output
    
    def last_hiddes(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), kernel_size=3, stride=3))
        #x = F.relu(F.max_pool2d(self.conv2(x), kernel_size=2, stride=2))
        x = F.relu(self.fc1(x.view(-1, 512)))
        return x

class NetGroups(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(2, 32, kernel_size=3, groups=2)
        #self.conv2 = nn.Conv2d(32, 64, kernel_size=2)
        self.fc1 = nn.Linear(512, 20)
        self.fc2 = nn.Linear(20, 2)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), kernel_size=3, stride=3))
        #x = F.relu(F.max_pool2d(self.conv2(x), kernel_size=2, stride=2))
        x = F.relu(self.fc1(x.view(-1, 512)))
        x = F.softmax(self.fc2(x), dim=1)
        #print(x)
        return x



In [5]:
m=Net()
mini_batch_size = 100

In [6]:
########################################################################################################################
# If you try to run the train using MSELoss like you were doing, you will get that loss2 and loss3 
# requires_grad = False
#####################################################################################################################

In [7]:
train_target = train_target.float()
train_classes = train_classes.float()
crit = nn.MSELoss

train_model(m, train_input, train_target, train_classes,mini_batch_size, crit, nb_epochs = 1)

|| loss1 req grad = True || loss2 req grad = False || loss3 req grad = False
|| loss1 req grad = True || loss2 req grad = False || loss3 req grad = False
|| loss1 req grad = True || loss2 req grad = False || loss3 req grad = False
|| loss1 req grad = True || loss2 req grad = False || loss3 req grad = False
|| loss1 req grad = True || loss2 req grad = False || loss3 req grad = False
|| loss1 req grad = True || loss2 req grad = False || loss3 req grad = False
|| loss1 req grad = True || loss2 req grad = False || loss3 req grad = False
|| loss1 req grad = True || loss2 req grad = False || loss3 req grad = False
|| loss1 req grad = True || loss2 req grad = False || loss3 req grad = False
|| loss1 req grad = True || loss2 req grad = False || loss3 req grad = False
0 tot loss 275.68832206726074 loss1 5.4483237862586975 loss2 138.38000011444092 loss3 131.85999870300293


In [8]:
###################################################################################################################
# This is because you used a sclar target for the 10 classes, using argmax before calling the loss
# Since argmax is a non differentiable function it sets the req. grad. of the result to False
# Example below:

In [9]:
a = torch.arange(5., requires_grad = True)
b = torch.argmax(a)
a.dtype, a.requires_grad, b.requires_grad

(torch.float32, True, False)

In [10]:
###################################################################################################################
# If you don't want to use hot label embedding you need to use the Cross Entropy instead. 
# I wrote some code in the training function and it seems to be working
##################################################################################################################

In [11]:
train_target = train_target.long()
train_classes = train_classes.long()
crit = nn.CrossEntropyLoss

train_model(m, train_input, train_target, train_classes,mini_batch_size, crit)

0 tot loss 54.7694149017334 loss1 8.514983236789703 loss2 22.990002632141113 loss3 23.26442861557007
1 tot loss 54.179779052734375 loss1 8.403752028942108 loss2 22.62832283973694 loss3 23.147704601287842
2 tot loss 53.54970359802246 loss1 7.923763692378998 loss2 22.581851720809937 loss3 23.044088125228882
3 tot loss 53.291667461395264 loss1 7.8056557178497314 loss2 22.58093237876892 loss3 22.905078649520874
4 tot loss 52.84849691390991 loss1 7.528869867324829 loss2 22.470181465148926 loss3 22.849446773529053
5 tot loss 52.62556171417236 loss1 7.361152231693268 loss2 22.444522380828857 loss3 22.819887399673462
6 tot loss 52.23776149749756 loss1 7.003297507762909 loss2 22.42754077911377 loss3 22.806923389434814
7 tot loss 52.029207706451416 loss1 6.863841116428375 loss2 22.410935163497925 loss3 22.75443124771118
8 tot loss 51.788204193115234 loss1 6.646120131015778 loss2 22.404370307922363 loss3 22.73771333694458
9 tot loss 51.44192457199097 loss1 6.365801751613617 loss2 22.3718464374542

80 tot loss 48.09533357620239 loss1 4.565479606389999 loss2 22.184742212295532 loss3 21.34511160850525
81 tot loss 48.107683181762695 loss1 4.576798588037491 loss2 22.186275959014893 loss3 21.344608545303345
82 tot loss 48.07877731323242 loss1 4.556395351886749 loss2 22.185813426971436 loss3 21.336567878723145
83 tot loss 47.97042751312256 loss1 4.465198248624802 loss2 22.18732738494873 loss3 21.317902088165283
84 tot loss 48.04565382003784 loss1 4.543907970190048 loss2 22.20192575454712 loss3 21.299819469451904
85 tot loss 47.92909479141235 loss1 4.438454806804657 loss2 22.199656009674072 loss3 21.29098391532898
86 tot loss 48.27120113372803 loss1 4.77182000875473 loss2 22.209484100341797 loss3 21.289896726608276
87 tot loss 48.10299205780029 loss1 4.611643373966217 loss2 22.18812870979309 loss3 21.303220748901367
88 tot loss 47.94213151931763 loss1 4.470024019479752 loss2 22.18841004371643 loss3 21.283697366714478
89 tot loss 48.14700222015381 loss1 4.673428565263748 loss2 22.2070970

159 tot loss 47.16296148300171 loss1 4.035487353801727 loss2 22.225329160690308 loss3 20.90214490890503
160 tot loss 47.17938470840454 loss1 4.052687704563141 loss2 22.227431058883667 loss3 20.899266242980957
161 tot loss 47.18211793899536 loss1 4.058861404657364 loss2 22.225942134857178 loss3 20.89731478691101
162 tot loss 47.15498876571655 loss1 4.036004066467285 loss2 22.22473978996277 loss3 20.8942449092865
163 tot loss 47.128056049346924 loss1 4.009431838989258 loss2 22.224837064743042 loss3 20.893786907196045
164 tot loss 47.11588764190674 loss1 4.000841826200485 loss2 22.224166870117188 loss3 20.890878200531006
165 tot loss 47.10334777832031 loss1 3.9911212027072906 loss2 22.223888635635376 loss3 20.8883376121521
166 tot loss 47.096848487854004 loss1 3.98691788315773 loss2 22.224177837371826 loss3 20.88575267791748
167 tot loss 47.08337640762329 loss1 3.977357655763626 loss2 22.221866369247437 loss3 20.884153842926025
168 tot loss 47.07742977142334 loss1 3.9735409915447235 loss2

In [12]:
nb_test_errors = compute_nb_errors(m, test_input, test_target, mini_batch_size = 100)
print('test error Net {:0.2f}% {:d}/{:d}'.format((100 * nb_test_errors) / test_input.size(0),nb_test_errors, test_input.size(0)))

test error Net 25.30% 253/1000
