In [27]:
#!/usr/bin/env python

import torch
import various_data_functions
from torch import nn
from torch import optim
from torch import Tensor
from torch.nn import functional as F
import dlc_practical_prologue as prologue
import matplotlib.pyplot as plt
%matplotlib notebook

In [38]:
#Data generation
N=10**3
#train_input,train_target,train_classes,test_input,test_target,test_classes=prologue.generate_pair_sets(N)
train_input,train_target,train_classes,test_input,test_target,test_classes=various_data_functions.data(N,True,False,nn.CrossEntropyLoss)
#train_target=train_target.long()#.float for MSELoss, .long for CrossEntropy
#train_input=train_input.float()
#train_classes=train_classes.long()

In [39]:
#Base functions adapted from the practicals
def train_model(model, train_input, train_target,train_classes, mini_batch_size, crit=nn.MSELoss, eta = 1e-3, nb_epochs = 250,print_=False):
    criterion = crit()
    optimizer = optim.SGD(model.parameters(), lr = eta)
    for e in range(nb_epochs):
        acc_loss = 0
        acc_loss1 = 0
        acc_loss2 = 0
        acc_loss3 = 0

        for b in range(0, train_input.size(0), mini_batch_size):
            output,aux_output = model(train_input.narrow(0, b, mini_batch_size))
            if crit==nn.MSELoss:
                loss1 = criterion(output[:,1], train_target.narrow(0, b, mini_batch_size))
                #print(torch.argmax(aux_output[:,0:9],dim=1))
                #print(train_classes[:,0].narrow(0, b, mini_batch_size))
                loss2 = criterion(torch.argmax(aux_output[:,0:9],dim=1), train_classes[:,0].narrow(0, b, mini_batch_size))
                loss3 = criterion(torch.argmax(aux_output[:,10:19],dim=1), train_classes[:,1].narrow(0, b, mini_batch_size))
                loss = loss1 + loss2 + loss3
                print('|| loss1 req grad =', loss1.requires_grad, '|| loss2 req grad =',loss2.requires_grad,'|| loss3 req grad =', loss3.requires_grad)
            elif crit==nn.CrossEntropyLoss:
                loss1 = criterion(output, train_target.narrow(0, b, mini_batch_size))
                #print(torch.argmax(aux_output[:,0:9],dim=1))
                #print(train_classes[:,0].narrow(0, b, mini_batch_size))
                loss2 = criterion(aux_output[:,:10], train_classes[:,0].narrow(0, b, mini_batch_size))
                loss3 = criterion(aux_output[:,10:], train_classes[:,1].narrow(0, b, mini_batch_size))
                loss = loss1 + 0.1*(loss2 + loss3)
                #print(loss1, loss2.requires_grad, loss3.requires_grad)
            else:
                print("Loss not implemented")
            acc_loss = acc_loss + loss.item()
            acc_loss1 = acc_loss1 + loss1.item()
            acc_loss2 = acc_loss2 + loss2.item()
            acc_loss3 = acc_loss3 + loss3.item()
            model.zero_grad()
            loss.backward()
            optimizer.step()
            if False:
                with torch.no_grad():
                    for p in model.parameters():
                        p -= eta * p.grad

        print(e, 'tot loss', acc_loss, 'loss1', acc_loss1, 'loss2', acc_loss2, 'loss3', acc_loss3)
            
def compute_nb_errors(model, input, target, mini_batch_size=100):
    nb_errors = 0

    for b in range(0, input.size(0), mini_batch_size):
        output , aux_output = model(input.narrow(0, b, mini_batch_size))
        _, predicted_classes = output.max(1)
        for k in range(mini_batch_size):
            if target[b + k]!=predicted_classes[k]:
                nb_errors = nb_errors + 1

    return nb_errors

def run_many_times(model,crit=nn.MSELoss,mini_batch_size=100,n=10,print_=False):
    average_error=0
    for i in range(n):
        m=model()
        train_model(m, train_input, train_target,train_classes,mini_batch_size,crit=crit)
        nb_test_errors = compute_nb_errors(m, test_input, test_target, mini_batch_size)
        print('test error Net {:0.2f}% {:d}/{:d}'.format((100 * nb_test_errors) / test_input.size(0),
                                                      nb_test_errors, test_input.size(0)))
        average_error+=(100 * nb_test_errors) / test_input.size(0)
    print("Average error: "+str(average_error/n))

In [40]:
#Is it better to use groups or not?
#Takes about 2 hours to run
#about 22.5% error average without groups if we exclude outliers that get stuck and don't move
#about 21.5% error average with groups if we exclude outliers that get stuck and don't move
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(2, 100, kernel_size=3,groups=2)
        #self.conv2 = nn.Conv2d(32, 64, kernel_size=2)
        self.fc1 = nn.Linear(1600, 20)
        self.fc2 = nn.Linear(20, 2)
        self.aux_linear = nn.Linear(20, 20)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), kernel_size=3, stride=3))
        #x = F.relu(F.max_pool2d(self.conv2(x), kernel_size=2, stride=2))
        aux_output = F.softmax(self.fc1(x.view(-1, 1600)), dim=1)
        x = F.relu(self.fc1(x.view(-1, 1600)))
        output = F.softmax(self.fc2(x), dim=1)
        aux_output = F.softmax(x, dim=1)
        #print(x)
        return output, aux_output
    
    def last_hiddes(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), kernel_size=3, stride=3))
        #x = F.relu(F.max_pool2d(self.conv2(x), kernel_size=2, stride=2))
        x = F.relu(self.fc1(x.view(-1, 1600)))
        return x

class NetGroups(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(2, 32, kernel_size=3, groups=2)
        #self.conv2 = nn.Conv2d(32, 64, kernel_size=2)
        self.fc1 = nn.Linear(512, 20)
        self.fc2 = nn.Linear(20, 2)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), kernel_size=3, stride=3))
        #x = F.relu(F.max_pool2d(self.conv2(x), kernel_size=2, stride=2))
        x = F.relu(self.fc1(x.view(-1, 512)))
        x = F.softmax(self.fc2(x), dim=1)
        #print(x)
        return x



In [41]:
m=Net()
mini_batch_size = 100

In [42]:
########################################################################################################################
# If you try to run the train using MSELoss like you were doing, you will get that loss2 and loss3 
# requires_grad = False
#####################################################################################################################

In [43]:
train_target = train_target.float()
train_classes = train_classes.float()
crit = nn.MSELoss

train_model(m, train_input, train_target, train_classes,mini_batch_size, crit, nb_epochs = 1)

|| loss1 req grad = True || loss2 req grad = False || loss3 req grad = False
|| loss1 req grad = True || loss2 req grad = False || loss3 req grad = False
|| loss1 req grad = True || loss2 req grad = False || loss3 req grad = False
|| loss1 req grad = True || loss2 req grad = False || loss3 req grad = False
|| loss1 req grad = True || loss2 req grad = False || loss3 req grad = False
|| loss1 req grad = True || loss2 req grad = False || loss3 req grad = False
|| loss1 req grad = True || loss2 req grad = False || loss3 req grad = False
|| loss1 req grad = True || loss2 req grad = False || loss3 req grad = False
|| loss1 req grad = True || loss2 req grad = False || loss3 req grad = False
|| loss1 req grad = True || loss2 req grad = False || loss3 req grad = False
0 tot loss 323.7322196960449 loss1 2.6422192454338074 loss2 121.73000144958496 loss3 199.3599967956543


In [44]:
###################################################################################################################
# This is because you used a sclar target for the 10 classes, using argmax before calling the loss
# Since argmax is a non differentiable function it sets the req. grad. of the result to False
# Example below:

In [45]:
a = torch.arange(5., requires_grad = True)
b = torch.argmax(a)
a.dtype, a.requires_grad, b.requires_grad

(torch.float32, True, False)

In [46]:
###################################################################################################################
# If you don't want to use hot label embedding you need to use the Cross Entropy instead. 
# I wrote some code in the training function and it seems to be working
##################################################################################################################

In [47]:
train_target = train_target.long()
train_classes = train_classes.long()
crit = nn.CrossEntropyLoss

train_model(m, train_input, train_target, train_classes,mini_batch_size, crit)

0 tot loss 7.016153454780579 loss1 7.016153454780579 loss2 23.03327965736389 loss3 23.031555652618408
1 tot loss 6.994461953639984 loss1 6.994461953639984 loss2 23.03477692604065 loss3 23.032089710235596
2 tot loss 6.9755819439888 loss1 6.9755819439888 loss2 23.036263942718506 loss3 23.03273367881775
3 tot loss 6.957402348518372 loss1 6.957402348518372 loss2 23.03774857521057 loss3 23.03362727165222
4 tot loss 6.939821124076843 loss1 6.939821124076843 loss2 23.03927254676819 loss3 23.034764289855957
5 tot loss 6.922888100147247 loss1 6.922888100147247 loss2 23.040833950042725 loss3 23.03585720062256
6 tot loss 6.90782630443573 loss1 6.90782630443573 loss2 23.042372703552246 loss3 23.036945104599
7 tot loss 6.893777668476105 loss1 6.893777668476105 loss2 23.043835163116455 loss3 23.038168907165527
8 tot loss 6.880793809890747 loss1 6.880793809890747 loss2 23.045262336730957 loss3 23.039315700531006
9 tot loss 6.869109869003296 loss1 6.869109869003296 loss2 23.04661536216736 loss3 23.040

80 tot loss 6.554215610027313 loss1 6.554215610027313 loss2 23.075857162475586 loss3 23.066778898239136
81 tot loss 6.549827814102173 loss1 6.549827814102173 loss2 23.076112270355225 loss3 23.066948175430298
82 tot loss 6.545384585857391 loss1 6.545384585857391 loss2 23.076367616653442 loss3 23.067119598388672
83 tot loss 6.54087507724762 loss1 6.54087507724762 loss2 23.076631546020508 loss3 23.067291259765625
84 tot loss 6.53634512424469 loss1 6.53634512424469 loss2 23.076889991760254 loss3 23.067465782165527
85 tot loss 6.531810283660889 loss1 6.531810283660889 loss2 23.07711672782898 loss3 23.067644119262695
86 tot loss 6.5272422432899475 loss1 6.5272422432899475 loss2 23.077343702316284 loss3 23.06782555580139
87 tot loss 6.522688806056976 loss1 6.522688806056976 loss2 23.077577352523804 loss3 23.068009614944458
88 tot loss 6.518051445484161 loss1 6.518051445484161 loss2 23.077810764312744 loss3 23.06818652153015
89 tot loss 6.513375520706177 loss1 6.513375520706177 loss2 23.078078

159 tot loss 6.185571908950806 loss1 6.185571908950806 loss2 23.094122171401978 loss3 23.08172631263733
160 tot loss 6.1810959577560425 loss1 6.1810959577560425 loss2 23.094354391098022 loss3 23.081928730010986
161 tot loss 6.176643252372742 loss1 6.176643252372742 loss2 23.09458565711975 loss3 23.082136631011963
162 tot loss 6.172203600406647 loss1 6.172203600406647 loss2 23.09481430053711 loss3 23.08233642578125
163 tot loss 6.167774438858032 loss1 6.167774438858032 loss2 23.095046997070312 loss3 23.082534790039062
164 tot loss 6.163355886936188 loss1 6.163355886936188 loss2 23.0952787399292 loss3 23.082732677459717
165 tot loss 6.158950865268707 loss1 6.158950865268707 loss2 23.095506191253662 loss3 23.082929849624634
166 tot loss 6.154558300971985 loss1 6.154558300971985 loss2 23.095738172531128 loss3 23.083128213882446
167 tot loss 6.15018230676651 loss1 6.15018230676651 loss2 23.095962285995483 loss3 23.083324432373047
168 tot loss 6.145814597606659 loss1 6.145814597606659 loss2 

238 tot loss 5.879753649234772 loss1 5.879753649234772 loss2 23.10991358757019 loss3 23.096988201141357
239 tot loss 5.8765029311180115 loss1 5.8765029311180115 loss2 23.11006999015808 loss3 23.097179651260376
240 tot loss 5.8732717633247375 loss1 5.8732717633247375 loss2 23.110225439071655 loss3 23.097376585006714
241 tot loss 5.870048940181732 loss1 5.870048940181732 loss2 23.11038112640381 loss3 23.09757399559021
242 tot loss 5.866842448711395 loss1 5.866842448711395 loss2 23.11053705215454 loss3 23.097773790359497
243 tot loss 5.863650918006897 loss1 5.863650918006897 loss2 23.110686779022217 loss3 23.097971439361572
244 tot loss 5.860475480556488 loss1 5.860475480556488 loss2 23.110838174819946 loss3 23.09816598892212
245 tot loss 5.857313454151154 loss1 5.857313454151154 loss2 23.110985279083252 loss3 23.09836435317993
246 tot loss 5.854165434837341 loss1 5.854165434837341 loss2 23.11113715171814 loss3 23.09856104850769
247 tot loss 5.851036250591278 loss1 5.851036250591278 loss2

In [48]:
nb_test_errors = compute_nb_errors(m, test_input, test_target, mini_batch_size = 100)
print('test error Net {:0.2f}% {:d}/{:d}'.format((100 * nb_test_errors) / test_input.size(0),nb_test_errors, test_input.size(0)))

test error Net 23.00% 230/1000


In [109]:
num=0
for p in m.parameters():
    print(p.shape)
    num+=1
print(num)

torch.Size([32, 1, 3, 3])
torch.Size([32])
torch.Size([20, 512])
torch.Size([20])
torch.Size([2, 20])
torch.Size([2])
torch.Size([20, 20])
torch.Size([20])
8
