In [1]:
#!/usr/bin/env python

import torch
from torch import nn
from torch import optim
from torch import Tensor
from torch.nn import functional as F
import dlc_practical_prologue as prologue
import matplotlib.pyplot as plt
%matplotlib notebook

In [2]:
#Data generation
N = 1000  # size argument handed to prologue.generate_pair_sets
(train_input, train_target, train_classes,
 test_input, test_target, test_classes) = prologue.generate_pair_sets(N)
# Cast the training tensors once up front: float inputs, long labels.
# (.long() suits CrossEntropy; switch the targets to .float() for MSELoss.)
train_input = train_input.float()
train_target = train_target.long()
train_classes = train_classes.long()

In [119]:
#Base functions adapted from the practicals
def train_model(model, train_input, train_target,train_classes, mini_batch_size, crit=nn.MSELoss, eta = 1e-3, nb_epochs = 250,print_=False):
    """Train `model` in place with plain SGD, printing the summed losses per epoch.

    `model` is called on every mini-batch and must return a pair
    ``(output, aux_output)``. `aux_output` is read as two groups of 10 columns:
    columns 0-9 are scored against ``train_classes[:, 0]`` and columns 10-19
    against ``train_classes[:, 1]``.

    Args:
        model: module to optimise; its parameters are updated in place.
        train_input: input tensor, batched along dimension 0.
        train_target: main targets, one per sample (float for MSELoss, long
            for CrossEntropyLoss).
        train_classes: tensor whose columns 0 and 1 hold the auxiliary targets.
        mini_batch_size: samples per optimisation step; the final batch of an
            epoch is smaller when the dataset size is not a multiple of it.
        crit: loss *class* (not instance): ``nn.MSELoss`` or
            ``nn.CrossEntropyLoss``.
        eta: SGD learning rate.
        nb_epochs: number of passes over the training set.
        print_: accepted for backward compatibility; currently unused.

    Raises:
        ValueError: if `crit` is neither ``nn.MSELoss`` nor
            ``nn.CrossEntropyLoss``.
    """
    # Fail fast: previously an unknown criterion only printed a message and
    # then crashed on the undefined `loss` a few lines later.
    if crit not in (nn.MSELoss, nn.CrossEntropyLoss):
        raise ValueError("Loss not implemented")
    use_mse = crit == nn.MSELoss

    criterion = crit()
    optimizer = optim.SGD(model.parameters(), lr = eta)
    nb_samples = train_input.size(0)

    for e in range(nb_epochs):
        acc_loss = 0
        acc_loss1 = 0
        acc_loss2 = 0
        acc_loss3 = 0

        for b in range(0, nb_samples, mini_batch_size):
            # Clamp the length so a trailing partial batch does not make narrow() raise.
            batch_len = min(mini_batch_size, nb_samples - b)
            output, aux_output = model(train_input.narrow(0, b, batch_len))
            target = train_target.narrow(0, b, batch_len)
            digit1 = train_classes[:, 0].narrow(0, b, batch_len)
            digit2 = train_classes[:, 1].narrow(0, b, batch_len)

            if use_mse:
                loss1 = criterion(output[:, 1], target)
                # argmax is not differentiable, so loss2/loss3 carry no gradient
                # here (they are effectively monitoring values only). All 10
                # columns of each group are considered; the result is cast to
                # the target dtype so MSELoss sees matching dtypes.
                pred1 = torch.argmax(aux_output[:, :10], dim=1).to(digit1.dtype)
                pred2 = torch.argmax(aux_output[:, 10:20], dim=1).to(digit2.dtype)
                loss2 = criterion(pred1, digit1)
                loss3 = criterion(pred2, digit2)
                loss = loss1 + loss2 + loss3
                print('|| loss1 req grad =', loss1.requires_grad, '|| loss2 req grad =',loss2.requires_grad,'|| loss3 req grad =', loss3.requires_grad)
            else:
                loss1 = criterion(output, target)
                loss2 = criterion(aux_output[:, :10], digit1)
                loss3 = criterion(aux_output[:, 10:], digit2)
                # Only the main loss is optimised; the auxiliary losses are
                # computed for logging.
                loss = loss1 #+ 0.1*(loss2 + loss3)

            acc_loss = acc_loss + loss.item()
            acc_loss1 = acc_loss1 + loss1.item()
            acc_loss2 = acc_loss2 + loss2.item()
            acc_loss3 = acc_loss3 + loss3.item()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        print(e, 'tot loss', acc_loss, 'loss1', acc_loss1, 'loss2', acc_loss2, 'loss3', acc_loss3)
            
def compute_nb_errors(model, input, target, mini_batch_size=100):
    """Count how many samples of `input` are misclassified by `model`.

    `model` must return a pair whose first element holds per-class scores of
    shape (batch, nb_classes); the predicted class is the argmax over
    dimension 1. The second element of the pair is ignored.

    Args:
        model: callable returning ``(output, aux_output)`` for a batch.
        input: input tensor, batched along dimension 0.
        target: one class index per sample.
        mini_batch_size: samples evaluated per forward pass; the last batch
            may be smaller.

    Returns:
        int: number of samples whose predicted class differs from `target`.
    """
    nb_errors = 0
    nb_samples = input.size(0)

    # Evaluation only: no autograd graph is needed.
    with torch.no_grad():
        for b in range(0, nb_samples, mini_batch_size):
            # Clamp so a trailing partial batch is evaluated instead of raising.
            batch_len = min(mini_batch_size, nb_samples - b)
            output, _ = model(input.narrow(0, b, batch_len))
            predicted_classes = output.argmax(dim=1)
            # Flatten so an (N, 1) target compares element-wise rather than broadcasting.
            expected = target.narrow(0, b, batch_len).reshape(-1)
            nb_errors += int((predicted_classes != expected).sum().item())

    return nb_errors

def run_many_times(model,crit=nn.MSELoss,mini_batch_size=100,n=10,print_=False):
    """Train and evaluate `n` freshly built models and print their mean test error.

    `model` is a zero-argument factory (e.g. a network class). Each instance is
    trained on the notebook-level training tensors via `train_model` and scored
    on the notebook-level test tensors via `compute_nb_errors`. `print_` is
    accepted but not used.
    """
    error_sum = 0
    for _ in range(n):
        net = model()
        train_model(net, train_input, train_target, train_classes, mini_batch_size, crit=crit)
        wrong = compute_nb_errors(net, test_input, test_target, mini_batch_size)
        # Error rate of this run, in percent of the test set.
        error_pct = (100 * wrong) / test_input.size(0)
        print('test error Net {:0.2f}% {:d}/{:d}'.format(error_pct,
                                                      wrong, test_input.size(0)))
        error_sum += error_pct
    print("Average error: "+str(error_sum/n))

In [120]:
#Is it better to use groups or not?
#Takes about 2 hours to run
#about 22.5% error average without groups if we exclude outliers that get stuck and don't move
#about 21.5% error average with groups if we exclude outliers that get stuck and don't move
class Net(nn.Module):
    """Small convolutional classifier with a main head and an auxiliary output.

    `forward` returns ``(output, aux_output)``:
      * ``output``: softmax over the 2 units of `fc2`, shape (batch, 2);
      * ``aux_output``: softmax over the 20 hidden units of `fc1`,
        shape (batch, 20).

    The flatten to 1600 features matches 100 channels x 4 x 4 after the
    convolution and pooling, which is what a (batch, 2, 14, 14) input yields
    (assumed input shape; confirm against the data generator).

    NOTE(review): both heads are already softmax-normalised. When trained with
    nn.CrossEntropyLoss (which applies log-softmax internally) the softmax is
    effectively applied twice; confirm whether that is intended.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(2, 100, kernel_size=3,groups=2)
        self.fc1 = nn.Linear(1600, 20)
        self.fc2 = nn.Linear(20, 2)
        # Not used by forward(); kept so the parameter list, state_dict keys and
        # initialisation order stay unchanged.
        self.aux_linear = nn.Linear(20, 20)

    def forward(self, x):
        """Return ``(output, aux_output)`` for a batch `x`."""
        # Single pass through conv1/fc1; the earlier extra fc1 pass whose result
        # was immediately overwritten has been dropped.
        hidden = self.last_hiddes(x)
        output = F.softmax(self.fc2(hidden), dim=1)
        aux_output = F.softmax(hidden, dim=1)
        return output, aux_output

    def last_hiddes(self, x):
        """Return the ReLU-activated 20-unit hidden layer (output of `fc1`)."""
        x = F.relu(F.max_pool2d(self.conv1(x), kernel_size=3, stride=3))
        x = F.relu(self.fc1(x.view(-1, 1600)))
        return x

class NetGroups(nn.Module):
    """Grouped-convolution classifier producing a 2-way softmax.

    NOTE(review): `forward` returns a single tensor, whereas `train_model` and
    `compute_nb_errors` in this notebook unpack a pair from the model call, so
    this class cannot be passed to them as-is.
    """

    def __init__(self):
        super().__init__()
        # groups=2: each of the 2 input channels feeds its own 16 filters.
        self.conv1 = nn.Conv2d(2, 32, kernel_size=3, groups=2)
        self.fc1 = nn.Linear(512, 20)
        self.fc2 = nn.Linear(20, 2)

    def forward(self, x):
        """Return softmax class probabilities of shape (batch, 2)."""
        pooled = F.max_pool2d(self.conv1(x), kernel_size=3, stride=3)
        features = F.relu(pooled).view(-1, 512)
        hidden = F.relu(self.fc1(features))
        logits = self.fc2(hidden)
        return F.softmax(logits, dim=1)



In [121]:
# Batch size used by the training calls in the cells below.
mini_batch_size = 100
# Fresh, untrained network.
m = Net()

In [122]:
########################################################################################################################
# If you run the training with MSELoss as you were doing, loss2 and loss3 end up with
# requires_grad = False, so they contribute no gradient to the update
#####################################################################################################################

In [123]:
# MSELoss needs floating-point targets, so cast the labels before this run.
crit = nn.MSELoss
train_classes = train_classes.float()
train_target = train_target.float()

# One epoch is enough to show the requires_grad flags printed per batch.
train_model(m, train_input, train_target, train_classes, mini_batch_size, crit=crit, nb_epochs=1)

|| loss1 req grad = True || loss2 req grad = False || loss3 req grad = False
|| loss1 req grad = True || loss2 req grad = False || loss3 req grad = False
|| loss1 req grad = True || loss2 req grad = False || loss3 req grad = False
|| loss1 req grad = True || loss2 req grad = False || loss3 req grad = False
|| loss1 req grad = True || loss2 req grad = False || loss3 req grad = False
|| loss1 req grad = True || loss2 req grad = False || loss3 req grad = False
|| loss1 req grad = True || loss2 req grad = False || loss3 req grad = False
|| loss1 req grad = True || loss2 req grad = False || loss3 req grad = False
|| loss1 req grad = True || loss2 req grad = False || loss3 req grad = False
|| loss1 req grad = True || loss2 req grad = False || loss3 req grad = False
0 tot loss 355.62580490112305 loss1 4.455808699131012 loss2 243.2999973297119 loss3 107.86999988555908


In [124]:
###################################################################################################################
# This is because you used a scalar target for the 10 classes, applying argmax before calling the loss.
# Since argmax is a non-differentiable function, its result has requires_grad set to False.
# Example below:

In [125]:
# A float tensor that tracks gradients...
a = torch.arange(5.0, requires_grad=True)
# ...and the index of its maximum, which is an integer result.
b = a.argmax()
(a.dtype, a.requires_grad, b.requires_grad)

(torch.float32, True, False)

In [126]:
###################################################################################################################
# If you don't want to use one-hot label encoding, you need to use Cross Entropy instead.
# I wrote some code in the training function and it seems to be working.
##################################################################################################################

In [127]:
# CrossEntropyLoss takes integer class indices, so cast the labels back to long.
crit = nn.CrossEntropyLoss
train_classes = train_classes.long()
train_target = train_target.long()

# Continues training the same `m` that went through the MSELoss epoch above.
train_model(m, train_input, train_target, train_classes, mini_batch_size, crit=crit)

0 tot loss 7.083635568618774 loss1 7.083635568618774 loss2 23.50584363937378 loss3 23.36737632751465
1 tot loss 6.5514044761657715 loss1 6.5514044761657715 loss2 23.662835121154785 loss3 23.397489309310913
2 tot loss 6.548925280570984 loss1 6.548925280570984 loss2 23.726492643356323 loss3 23.351585626602173
3 tot loss 6.506751835346222 loss1 6.506751835346222 loss2 23.75475788116455 loss3 23.20123839378357
4 tot loss 6.376391887664795 loss1 6.376391887664795 loss2 23.70777940750122 loss3 23.31951332092285
5 tot loss 6.26804780960083 loss1 6.26804780960083 loss2 23.748292684555054 loss3 23.332770109176636
6 tot loss 6.0114635825157166 loss1 6.0114635825157166 loss2 23.88239622116089 loss3 23.285412073135376
7 tot loss 5.968484044075012 loss1 5.968484044075012 loss2 23.83894991874695 loss3 23.31095242500305
8 tot loss 5.975208044052124 loss1 5.975208044052124 loss2 23.698276042938232 loss3 23.47767210006714
9 tot loss 5.887198984622955 loss1 5.887198984622955 loss2 23.792410612106323 los

80 tot loss 4.432866990566254 loss1 4.432866990566254 loss2 23.47458815574646 loss3 23.32164168357849
81 tot loss 4.276668727397919 loss1 4.276668727397919 loss2 23.566107273101807 loss3 23.324901819229126
82 tot loss 4.64102166891098 loss1 4.64102166891098 loss2 23.529914140701294 loss3 23.309309482574463
83 tot loss 4.436985820531845 loss1 4.436985820531845 loss2 23.55938458442688 loss3 23.336621284484863
84 tot loss 4.5342302322387695 loss1 4.5342302322387695 loss2 23.64218258857727 loss3 23.26150369644165
85 tot loss 4.42035984992981 loss1 4.42035984992981 loss2 23.59520387649536 loss3 23.237106323242188
86 tot loss 4.414541035890579 loss1 4.414541035890579 loss2 23.523736715316772 loss3 23.278398990631104
87 tot loss 4.315794736146927 loss1 4.315794736146927 loss2 23.567427396774292 loss3 23.26348853111267
88 tot loss 4.471444934606552 loss1 4.471444934606552 loss2 23.536428928375244 loss3 23.321644067764282
89 tot loss 4.3768496215343475 loss1 4.3768496215343475 loss2 23.57113909

159 tot loss 3.6513370871543884 loss1 3.6513370871543884 loss2 23.67773723602295 loss3 23.267898082733154
160 tot loss 3.6502725780010223 loss1 3.6502725780010223 loss2 23.677562713623047 loss3 23.27153730392456
161 tot loss 3.6490380465984344 loss1 3.6490380465984344 loss2 23.67766523361206 loss3 23.2720844745636
162 tot loss 3.649061918258667 loss1 3.649061918258667 loss2 23.674086809158325 loss3 23.278908014297485
163 tot loss 3.6484929025173187 loss1 3.6484929025173187 loss2 23.670122146606445 loss3 23.28499436378479
164 tot loss 3.649860829114914 loss1 3.649860829114914 loss2 23.669405221939087 loss3 23.28862953186035
165 tot loss 3.65177184343338 loss1 3.65177184343338 loss2 23.662638425827026 loss3 23.293423414230347
166 tot loss 3.659593552350998 loss1 3.659593552350998 loss2 23.65524935722351 loss3 23.289854049682617
167 tot loss 3.6511134803295135 loss1 3.6511134803295135 loss2 23.667009592056274 loss3 23.292081594467163
168 tot loss 4.109481602907181 loss1 4.109481602907181 

238 tot loss 3.524706870317459 loss1 3.524706870317459 loss2 23.603856086730957 loss3 23.439411878585815
239 tot loss 3.524441570043564 loss1 3.524441570043564 loss2 23.6042218208313 loss3 23.438949584960938
240 tot loss 3.524189442396164 loss1 3.524189442396164 loss2 23.60575008392334 loss3 23.439419507980347
241 tot loss 3.5239596366882324 loss1 3.5239596366882324 loss2 23.60478162765503 loss3 23.43895983695984
242 tot loss 3.523828446865082 loss1 3.523828446865082 loss2 23.60513925552368 loss3 23.43990421295166
243 tot loss 3.5236441493034363 loss1 3.5236441493034363 loss2 23.60492777824402 loss3 23.439867734909058
244 tot loss 3.5234424769878387 loss1 3.5234424769878387 loss2 23.60517954826355 loss3 23.4414701461792
245 tot loss 3.523239940404892 loss1 3.523239940404892 loss2 23.603865385055542 loss3 23.441389083862305
246 tot loss 3.5230972170829773 loss1 3.5230972170829773 loss2 23.603363752365112 loss3 23.441514015197754
247 tot loss 3.522908538579941 loss1 3.522908538579941 los

In [128]:
# Score the trained network on the test pairs and report the error rate.
nb_test_errors = compute_nb_errors(m, test_input, test_target, mini_batch_size=100)
print('test error Net {:0.2f}% {:d}/{:d}'.format(
    (100 * nb_test_errors) / test_input.size(0),
    nb_test_errors,
    test_input.size(0),
))

test error Net 19.00% 190/1000


In [109]:
# Print the shape of every parameter tensor of `m`, then how many there are.
num = 0
for num, p in enumerate(m.parameters(), start=1):
    print(p.shape)
print(num)

torch.Size([32, 1, 3, 3])
torch.Size([32])
torch.Size([20, 512])
torch.Size([20])
torch.Size([2, 20])
torch.Size([2])
torch.Size([20, 20])
torch.Size([20])
8
