In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


In [3]:
from google.colab import drive
drive.mount("/content/drive/")

Go to this URL in a browser: https://accounts.google.com/o/oauth2/auth?client_id=947318989803-6bn6qk8qdgf4n4g3pfee6491hc0brc4i.apps.googleusercontent.com&redirect_uri=urn%3aietf%3awg%3aoauth%3a2.0%3aoob&scope=email%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdocs.test%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive.photos.readonly%20https%3a%2f%2fwww.googleapis.com%2fauth%2fpeopleapi.readonly&response_type=code

Enter your authorization code:
··········
Mounted at /content/drive/


In [4]:
def binarize(x):
    return x.sign()

In [5]:
class BinConv2d(nn.Conv2d):
    def __init__(self, bin_fn, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.bin_fn = bin_fn

    def forward(self, x):
        x = binarize(x)
        if self.bin_fn == binarize:
            self.weight = nn.Parameter(self.bin_fn(self.weight))
        else:
            self.bin_fn = Mapping(x.shape[1]).to(device)
            self.weight = nn.Parameter(self.bin_fn(self.weight))
        out = nn.functional.conv2d(x, self.weight, None, self.stride, self.padding, self.dilation, self.groups)
        return out

In [6]:
class LambdaLayer(nn.Module):
    def __init__(self, lambd):
        super(LambdaLayer, self).__init__()
        self.lambd = lambd

    def forward(self, x):
        return self.lambd(x)

class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != planes:
            self.shortcut = LambdaLayer(lambda x: F.pad(x[:, :, ::2, ::2], (0, 0, 0, 0, planes//4, planes//4), "constant", 0))

    def forward(self, x):
        out = nn.ReLU()(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = nn.ReLU()(out)
        return out


class ResNet(nn.Module):
    def __init__(self, block, num_blocks, num_classes=10):
        super(ResNet, self).__init__()
        self.inflate = 4
        self.in_planes = 16 * self.inflate

        self.conv1 = nn.Conv2d(3, 16 * self.inflate, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(16 * self.inflate)
        self.res1 = self._make_layer(block, 16 * self.inflate, num_blocks[0], stride=1)
        self.res2 = self._make_layer(block, 32 * self.inflate, num_blocks[1], stride=2)
        self.res3 = self._make_layer(block, 64 * self.inflate, num_blocks[2], stride=2)
        self.res_num = len(num_blocks)
        if self.res_num==4:
            self.res4 = self._make_layer(block, 128*self.inflate, num_blocks[3], stride=2)
            self.linear = nn.Linear(128 * self.inflate, num_classes)
        else:
            self.linear = nn.Linear(64 * self.inflate, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion

        return nn.Sequential(*layers)

    def forward(self, x):
        out = nn.ReLU()(self.bn1(self.conv1(x)))
        out = self.res1(out)
        out = self.res2(out)
        out = self.res3(out)
        if self.res_num==4:
            out = self.res4(out)
        out = F.avg_pool2d(out, out.size()[3])
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return nn.Softmax(dim = -1)(out)

In [7]:
class BinBlock(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1, bin_fn = "map"):
        super(BinBlock, self).__init__()
        self.conv1 = BinConv2d(bin_fn, in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = BinConv2d(bin_fn, planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.bin_fn = bin_fn

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != planes:
            self.shortcut = LambdaLayer(lambda x: F.pad(x[:, :, ::2, ::2], (0, 0, 0, 0, planes//4, planes//4), "constant", 0))

    def forward(self, x):
        out = self.conv1(x)
        out = self.bn1(out)
        out = nn.Hardtanh()(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out += self.shortcut(x)
        out = nn.Hardtanh()(out)
        return out

class Mapping(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, 2 * channels, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(2 * channels)
        self.conv2 = nn.Conv2d(2 * channels, 2 * channels, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(2 * channels)
        self.conv3 = nn.Conv2d(2 * channels, channels, kernel_size=3, stride=1, padding=1)
        self.relu = nn.ReLU()

    def forward(self, x):
        out = self.conv1(x)
        out = self.relu(self.bn1(out))
        out = self.conv2(out)
        out = self.relu(self.bn2(out))
        out = self.conv3(out)
        return out

class BinResNet(nn.Module):
    def __init__(self, block, num_blocks, num_classes, bin_fn):
        super(BinResNet, self).__init__()
        self.bin_fn = bin_fn
        self.inflate = 4
        self.in_planes = 16 * self.inflate
        self.conv1 = nn.Conv2d(3, 16 * self.inflate, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(16 * self.inflate)
        self.res1 = self._make_layer(block, 16 * self.inflate, num_blocks[0], stride=1)
        self.res2 = self._make_layer(block, 32 * self.inflate, num_blocks[1], stride=2)
        self.res3 = self._make_layer(block, 64 * self.inflate, num_blocks[2], stride=2)
        self.res_num = len(num_blocks)
        if self.res_num==4:
            self.res4 = self._make_layer(block, 128*self.inflate, num_blocks[3], stride=2)
            self.linear = nn.Linear(128 * self.inflate, num_classes)
        else:
            self.linear = nn.Linear(64 * self.inflate, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride, self.bin_fn))
            self.in_planes = planes * block.expansion

        return nn.Sequential(*layers)

    def forward(self, x):
        out = self.bn1(self.conv1(x))
        out = binarize(out)
        out = self.res1(out)
        out = self.res2(out)
        out = self.res3(out)
        if self.res_num==4:
            out = self.res4(out)
        out = F.avg_pool2d(out, out.size()[3])
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return nn.Softmax(dim = -1)(out)

In [8]:
def test():
    net = ResNet(BasicBlock, [3, 3, 3], 1000)
    x = torch.rand(2, 3, 224, 224)
    y = net(x)
    print(net.state_dict().keys())
    print(y.shape)

def test_bin():
    binnet = BinResNet(BinBlock, [3, 3, 3], 100, bin_fn = "map").to(device)
    x = torch.rand(2, 3, 224, 224).to(device)
    # print(binnet)
    y = binnet(x)
    print(y.shape)

test()
test_bin()

odict_keys(['conv1.weight', 'bn1.weight', 'bn1.bias', 'bn1.running_mean', 'bn1.running_var', 'bn1.num_batches_tracked', 'res1.0.conv1.weight', 'res1.0.bn1.weight', 'res1.0.bn1.bias', 'res1.0.bn1.running_mean', 'res1.0.bn1.running_var', 'res1.0.bn1.num_batches_tracked', 'res1.0.conv2.weight', 'res1.0.bn2.weight', 'res1.0.bn2.bias', 'res1.0.bn2.running_mean', 'res1.0.bn2.running_var', 'res1.0.bn2.num_batches_tracked', 'res1.1.conv1.weight', 'res1.1.bn1.weight', 'res1.1.bn1.bias', 'res1.1.bn1.running_mean', 'res1.1.bn1.running_var', 'res1.1.bn1.num_batches_tracked', 'res1.1.conv2.weight', 'res1.1.bn2.weight', 'res1.1.bn2.bias', 'res1.1.bn2.running_mean', 'res1.1.bn2.running_var', 'res1.1.bn2.num_batches_tracked', 'res1.2.conv1.weight', 'res1.2.bn1.weight', 'res1.2.bn1.bias', 'res1.2.bn1.running_mean', 'res1.2.bn1.running_var', 'res1.2.bn1.num_batches_tracked', 'res1.2.conv2.weight', 'res1.2.bn2.weight', 'res1.2.bn2.bias', 'res1.2.bn2.running_mean', 'res1.2.bn2.running_var', 'res1.2.bn2.nu

In [9]:
import torchvision
import torchvision.transforms as transforms

# Image preprocessing modules
transform = transforms.Compose([
    transforms.Pad(4),
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32),
    transforms.ToTensor()])

# CIFAR-10 dataset
train_dataset = torchvision.datasets.CIFAR10(root='../../data/',
                                             train=True, 
                                             transform=transform,
                                             download=True)

train_set, val_set = torch.utils.data.random_split(train_dataset, [int(0.8 * len(train_dataset)), int(0.2 * len(train_dataset))])
test_dataset = torchvision.datasets.CIFAR10(root='../../data/',
                                            train=False, 
                                            transform=transforms.ToTensor())

# Data loader
train_loader = torch.utils.data.DataLoader(dataset=train_set,
                                           batch_size=128, 
                                           shuffle=True)

val_loader = torch.utils.data.DataLoader(dataset=val_set,
                                         batch_size=128,
                                         shuffle = True)

test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                          batch_size=128, 
                                          shuffle=False)        

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ../../data/cifar-10-python.tar.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ../../data/cifar-10-python.tar.gz to ../../data/


In [10]:
bin = ["res", "conv", "weight"]
notbin = ["conv1.weight", "linear.weight", "linear.bias"]
map = ["bin_fn"]

In [11]:
def create_bin_dict(binnet):
    bin_fns = {}
    bin_fns["res1.0.conv1"] = binnet.res1[0].conv1.bin_fn
    bin_fns["res1.0.conv2"] = binnet.res1[0].conv1.bin_fn
    bin_fns["res1.1.conv1"] = binnet.res1[1].conv1.bin_fn
    bin_fns["res1.1.conv2"] = binnet.res1[1].conv1.bin_fn
    bin_fns["res1.2.conv1"] = binnet.res1[2].conv1.bin_fn
    bin_fns["res1.2.conv2"] = binnet.res1[2].conv1.bin_fn
    bin_fns["res2.0.conv1"] = binnet.res2[0].conv1.bin_fn
    bin_fns["res2.0.conv2"] = binnet.res2[0].conv1.bin_fn
    bin_fns["res2.1.conv1"] = binnet.res2[1].conv1.bin_fn
    bin_fns["res2.1.conv2"] = binnet.res2[1].conv1.bin_fn
    bin_fns["res2.2.conv1"] = binnet.res2[2].conv1.bin_fn
    bin_fns["res2.2.conv2"] = binnet.res2[2].conv1.bin_fn
    bin_fns["res3.0.conv1"] = binnet.res3[0].conv1.bin_fn
    bin_fns["res3.0.conv2"] = binnet.res3[0].conv1.bin_fn
    bin_fns["res3.1.conv1"] = binnet.res3[1].conv1.bin_fn
    bin_fns["res3.1.conv2"] = binnet.res3[1].conv1.bin_fn
    bin_fns["res3.2.conv1"] = binnet.res3[2].conv1.bin_fn
    bin_fns["res3.2.conv2"] = binnet.res3[2].conv1.bin_fn
    return bin_fns

In [12]:
def learn(net, binnet):
    bin_fns = create_bin_dict(binnet)
    noisy_loss = torch.tensor([0.0], device=device)
    grad_aux_W = {}
    for (name_q, q), (name_p, p) in zip(binnet.named_parameters(), net.named_parameters()):
        if name_p==name_q:
            if all([i in name_p for i in bin]):
                W = p.data
                W.requires_grad = True
                Q = binarize(p.data)
                bin_fn  = bin_fns[name_p.replace(".weight", "")]
                Q_star = bin_fn(W)
                Q_star.to("cpu")
                aux_loss = torch.sum((1 - rhof) * nn.MSELoss()(Q, Q_star) - rhof * nn.MSELoss()(-Q, Q_star) / (1 - 2 * rhof)) # li = aux_loss
                aux_loss.backward() # should give d(li)/d(theta) theta - map params

                grad_aux_W[name_p] = W.grad.data # should store d(li)/dW
                noisy_loss += aux_loss

                assert q.data.shape == Q_star.shape
                q.data = Q_star
    
    return noisy_loss, binnet, grad_aux_W

In [13]:
def base_train(net, binnet):
    for (name_q, q), (name_p, p) in zip(binnet.named_parameters(), net.named_parameters()):
        if name_q==name_p:
            if all([i in name_p for i in bin]):
                q.data = p.data
            elif name_p in notbin:
                q.data = p.data

    return binnet

In [14]:
net = ResNet(BasicBlock, [3, 3, 3], 10).to(device)
binnet = BinResNet(BinBlock, [3, 3, 3], 10, bin_fn=binarize).to(device)
binnet_optim = torch.optim.SGD(binnet.parameters(), lr = 0.1, momentum = 0.9)
net_optim = torch.optim.SGD(net.parameters(), lr = 0.1, momentum =0.9)
rhof = 0.005

In [None]:
pretrain_losses = []
train_losses = []

In [None]:
num_pretrain_epochs=400
for epoch in range(num_pretrain_epochs):
    for i, (images, labels) in enumerate(train_loader):
        # print(i)
        images = images.to(device)
        # print(images.shape)
        labels = labels.to(device)

        # Forward pass
        binnet = base_train(net, binnet)

        out = binnet(images)
        # print(out.shape)
        ce_loss = nn.CrossEntropyLoss()(out, labels)
        
        binnet_optim.zero_grad()
        net_optim.zero_grad()
        ce_loss.backward()
        nn.utils.clip_grad_value_(binnet.parameters(), 1)

        for p, (name, q) in zip(net.parameters(), binnet.named_parameters()):
            if q.grad is not None:
                p.grad = q.grad

        binnet_optim.step()
        net_optim.step()
        trainaccuracy = torch.sum(labels==torch.argmax(out, dim=-1)).to("cpu").item()/len(images)

    if (epoch+1)%1==0:
        with torch.no_grad():
            x = torch.tensor([0.0], device = device)
            for images, labels in val_loader:
                images = images.to(device)
                labels = labels.to(device)
                # print(labels.shape)
                out = binnet(images)
                # print(torch.argmax(out, dim = -1).shape)
                x += torch.sum(labels == torch.argmax(out, dim = -1))
            valaccuracy = x.to("cpu").item()/10000
        print(f"Epoch [{epoch+1}/{num_pretrain_epochs}] Loss = {ce_loss.item()} TrainAccuracy = {trainaccuracy} ValAccuracy = {valaccuracy}")
    
    if (epoch+1) % 10 == 0:
        torch.save({"epoch": epoch+1, "net_state_dict": net.state_dict(), "binnet_state_dict": binnet.state_dict()}, "/content/drive/My Drive/Artificial intelligence/bigmodel.th")
        print(f"{epoch+1} saved")

Epoch [1/400] Loss = 2.2652363777160645 TrainAccuracy = 0.125 ValAccuracy = 0.1912
Epoch [2/400] Loss = 2.2258291244506836 TrainAccuracy = 0.21875 ValAccuracy = 0.2069
Epoch [3/400] Loss = 2.2424449920654297 TrainAccuracy = 0.171875 ValAccuracy = 0.2262
Epoch [4/400] Loss = 2.218268632888794 TrainAccuracy = 0.296875 ValAccuracy = 0.228
Epoch [5/400] Loss = 2.2504844665527344 TrainAccuracy = 0.1875 ValAccuracy = 0.2378
Epoch [6/400] Loss = 2.1801698207855225 TrainAccuracy = 0.296875 ValAccuracy = 0.2416
Epoch [7/400] Loss = 2.232516050338745 TrainAccuracy = 0.1875 ValAccuracy = 0.2447
Epoch [8/400] Loss = 2.209232807159424 TrainAccuracy = 0.265625 ValAccuracy = 0.2475
Epoch [9/400] Loss = 2.200521230697632 TrainAccuracy = 0.234375 ValAccuracy = 0.253
Epoch [10/400] Loss = 2.1866190433502197 TrainAccuracy = 0.234375 ValAccuracy = 0.2549
10 saved
Epoch [11/400] Loss = 2.1807570457458496 TrainAccuracy = 0.265625 ValAccuracy = 0.2559
Epoch [12/400] Loss = 2.1015567779541016 TrainAccuracy = 

In [None]:
# if not trained to 400 epochs
x = torch.load("/content/drive/My Drive/Artificial intelligence/bigmodel.th")
net.load_state_dict(x["net_state_dict"])
binnet.load_state_dict(x["binnet_state_dict"])
for epoch in range(x["epoch"]-1, 400):
    for i, (images, labels) in enumerate(train_loader):
        # print(i)
        images = images.to(device)
        # print(images.shape)
        labels = labels.to(device)

        # Forward pass
        binnet = base_train(net, binnet)

        out = binnet(images)
        # print(out.shape)
        ce_loss = nn.CrossEntropyLoss()(out, labels)
        
        binnet_optim.zero_grad()
        net_optim.zero_grad()
        ce_loss.backward()
        nn.utils.clip_grad_value_(binnet.parameters(), 1)

        for p, (name, q) in zip(net.parameters(), binnet.named_parameters()):
            if q.grad is not None:
                p.grad = q.grad

        binnet_optim.step()
        net_optim.step()
        trainaccuracy = torch.sum(labels==torch.argmax(out, dim=-1)).to("cpu").item()/len(images)

    if (epoch+1)%1==0:
        with torch.no_grad():
            x = torch.tensor([0.0], device = device)
            for images, labels in val_loader:
                images = images.to(device)
                labels = labels.to(device)
                # print(labels.shape)
                out = binnet(images)
                # print(torch.argmax(out, dim = -1).shape)
                x += torch.sum(labels == torch.argmax(out, dim = -1))
            valaccuracy = x.to("cpu").item()/10000
        print(f"Epoch [{epoch+1}/400] Loss = {ce_loss.item()} TrainAccuracy = {trainaccuracy} ValAccuracy = {valaccuracy}")
    
    if (epoch+1) % 10 == 0:
        torch.save({"epoch": epoch+1, "net_state_dict": net.state_dict(), "binnet_state_dict": binnet.state_dict()}, "/content/drive/My Drive/Artificial intelligence/bigmodel.th")
        print(f"{epoch+1} saved")

Epoch [140/400] Loss = 2.078464984893799 TrainAccuracy = 0.375 ValAccuracy = 0.3779
140 saved
Epoch [141/400] Loss = 2.1387104988098145 TrainAccuracy = 0.34375 ValAccuracy = 0.3782
Epoch [142/400] Loss = 2.0908939838409424 TrainAccuracy = 0.3125 ValAccuracy = 0.3818
Epoch [143/400] Loss = 2.1459243297576904 TrainAccuracy = 0.328125 ValAccuracy = 0.3778
Epoch [144/400] Loss = 2.148730754852295 TrainAccuracy = 0.3125 ValAccuracy = 0.3778
Epoch [145/400] Loss = 2.117278575897217 TrainAccuracy = 0.3125 ValAccuracy = 0.3845
Epoch [146/400] Loss = 2.068174362182617 TrainAccuracy = 0.390625 ValAccuracy = 0.3828
Epoch [147/400] Loss = 2.054394483566284 TrainAccuracy = 0.390625 ValAccuracy = 0.3828
Epoch [148/400] Loss = 2.054574489593506 TrainAccuracy = 0.390625 ValAccuracy = 0.3917
Epoch [149/400] Loss = 2.0920724868774414 TrainAccuracy = 0.359375 ValAccuracy = 0.3834
Epoch [150/400] Loss = 2.097684144973755 TrainAccuracy = 0.34375 ValAccuracy = 0.3871
150 saved
Epoch [151/400] Loss = 2.12638

In [None]:
import numpy as np
x = torch.load("/content/drive/My Drive/Artificial intelligence/bigmodel.th")
binnet = BinResNet(BinBlock, [3, 3, 3], 10, bin_fn=binarize).to(device)
binnet.load_state_dict(x["binnet_state_dict"])

with torch.no_grad():
    losses = []
    x = 0
    for i, (images, labels) in enumerate(test_loader):
        # print(i)
        images = images.to(device)
        # print(images.shape)
        labels = labels.to(device)

        out = binnet(images)
        ce_loss = nn.CrossEntropyLoss()(out, labels)
        losses.append(ce_loss.item())
        x += torch.sum(labels==torch.argmax(out, dim=-1)).to("cpu").item()
    print(f"loss {np.mean(losses)} TestAccuracy {x/10000}")

loss 1.9883591887317127 TestAccuracy 0.4699


In [15]:
mapbinnet = BinResNet(BinBlock, [3, 3, 3], 10, bin_fn="map").to(device)
x = torch.load("/content/drive/My Drive/Artificial intelligence/bigmodel.th")
net = ResNet(BinBlock, [3, 3, 3], 10).to(device)
net.load_state_dict(x["net_state_dict"])
mapbinnet_optim = torch.optim.SGD(mapbinnet.parameters(), lr = 0.1, momentum=0.9)
# Copy the full precision weights
for (name_p, p), (name_q, q) in zip(net.named_parameters(), mapbinnet.named_parameters()):
    if name_p==name_q:
        q.data = p.data

In [None]:
# Some warm up for the mapping conv weights
num_warmup_epochs = 400
for epoch in range(num_warmup_epochs):
    for i, (images, labels) in enumerate(train_loader):
        images = images.to(device)
        labels = labels.to(device)

        out = mapbinnet(images)
        net_optim.zero_grad()
        mapbinnet_optim.zero_grad()

        noisy_loss, mapbinne, grad_aux_W = learn(net, mapbinnet)
        ce_loss = nn.CrossEntropyLoss()(out, labels)
        ce_loss.backward()
        nn.utils.clip_grad_value_(mapbinnet.parameters(), 1)

        for (name_p, p), (name_q, q) in zip(net.named_parameters(), mapbinnet.named_parameters()):
            if name_p==name_q:
                if q.grad is not None:
                    p.grad = q.grad
                    if all([i in name_p for i in bin]):
                        p.grad.data += grad_aux_W[name_p]

        net_optim.step()
        mapbinnet_optim.step()

        # Changing the value of the other parameters to the saved ones
        for (name_p, p), (name_q, q) in zip(net.named_parameters(), mapbinnet.named_parameters()):
            if name_p == name_q:
                q.data = p.data
        
        trainaccuracy = torch.sum(labels==torch.argmax(out, dim=-1)).to("cpu").item()/len(images)

    with torch.no_grad():
            x = torch.tensor([0.0], device = device)
            for images, labels in val_loader:
                images = images.to(device)
                labels = labels.to(device)
                # print(labels.shape)
                out = mapbinnet(images)
                # print(torch.argmax(out, dim = -1).shape)
                x += torch.sum(labels == torch.argmax(out, dim = -1))
            valaccuracy = x.to("cpu").item()/10000

    print(f"Epoch [{epoch+1}/{num_warmup_epochs}] Loss = {ce_loss.item()} TrainAccuracy = {trainaccuracy} ValAccuracy = {valaccuracy}")
    if (epoch+1) % 10 == 0:
        torch.save({"epoch": epoch+1, "net_state_dict": net.state_dict(), "mapbinnet_state_dict": mapbinnet.state_dict()}, "/content/drive/My Drive/Artificial intelligence/bigmap.th")
        print(f"{epoch+1} saved")


Epoch [1/400] Loss = 2.282268762588501 TrainAccuracy = 0.125 ValAccuracy = 0.102
Epoch [2/400] Loss = 2.282072067260742 TrainAccuracy = 0.140625 ValAccuracy = 0.102
Epoch [3/400] Loss = 2.311079502105713 TrainAccuracy = 0.125 ValAccuracy = 0.1093
Epoch [4/400] Loss = 2.3521077632904053 TrainAccuracy = 0.0625 ValAccuracy = 0.1139
Epoch [5/400] Loss = 2.3116023540496826 TrainAccuracy = 0.140625 ValAccuracy = 0.1082
Epoch [6/400] Loss = 2.300687313079834 TrainAccuracy = 0.171875 ValAccuracy = 0.1399
Epoch [7/400] Loss = 2.314526319503784 TrainAccuracy = 0.125 ValAccuracy = 0.1505
Epoch [8/400] Loss = 2.3149054050445557 TrainAccuracy = 0.109375 ValAccuracy = 0.1655
Epoch [9/400] Loss = 2.249946117401123 TrainAccuracy = 0.171875 ValAccuracy = 0.1676
Epoch [10/400] Loss = 2.335204839706421 TrainAccuracy = 0.09375 ValAccuracy = 0.1707
10 saved
Epoch [11/400] Loss = 2.2633652687072754 TrainAccuracy = 0.1875 ValAccuracy = 0.1762
Epoch [12/400] Loss = 2.240863084793091 TrainAccuracy = 0.171875 V

In [16]:
y = torch.rand(3, 3, 224, 224).to(device)
z = mapbinnet(y)
print(z.shape)

torch.Size([3, 10])


In [None]:
# if not trained till 400 epochs
mapbinnet_optim = torch.optim.SGD(mapbinnet.parameters(), lr = 0.1, momentum=0.9)
x = torch.load("/content/drive/My Drive/Artificial intelligence/bigmap.th")
net.load_state_dict(x["net_state_dict"])
mapbinnet.load_state_dict(x["mapbinnet_state_dict"])
num_warmup_epochs = 400
for epoch in range(x["epoch"], num_warmup_epochs):
    for i, (images, labels) in enumerate(train_loader):
        images = images.to(device)
        labels = labels.to(device)

        out = mapbinnet(images)
        net_optim.zero_grad()
        mapbinnet_optim.zero_grad()

        noisy_loss, mapbinne, grad_aux_W = learn(net, mapbinnet)
        ce_loss = nn.CrossEntropyLoss()(out, labels)
        ce_loss.backward()
        nn.utils.clip_grad_value_(mapbinnet.parameters(), 1)

        for (name_p, p), (name_q, q) in zip(net.named_parameters(), mapbinnet.named_parameters()):
            if name_p==name_q:
                if q.grad is not None:
                    p.grad = q.grad
                    if all([i in name_p for i in bin]):
                        p.grad.data += grad_aux_W[name_p]

        net_optim.step()
        mapbinnet_optim.step()

        # Changing the value of the other parameters to the saved ones
        for (name_p, p), (name_q, q) in zip(net.named_parameters(), mapbinnet.named_parameters()):
            if name_p == name_q:
                q.data = p.data
        
        trainaccuracy = torch.sum(labels==torch.argmax(out, dim=-1)).to("cpu").item()/len(images)

    with torch.no_grad():
            x = torch.tensor([0.0], device = device)
            for images, labels in val_loader:
                images = images.to(device)
                labels = labels.to(device)
                # print(labels.shape)
                out = mapbinnet(images)
                # print(torch.argmax(out, dim = -1).shape)
                x += torch.sum(labels == torch.argmax(out, dim = -1))
            valaccuracy = x.to("cpu").item()/10000

    print(f"Epoch [{epoch+1}/{num_warmup_epochs}] Loss = {ce_loss.item()} TrainAccuracy = {trainaccuracy} ValAccuracy = {valaccuracy}")
    if (epoch+1) % 10 == 0:
        torch.save({"epoch": epoch+1, "net_state_dict": net.state_dict(), "mapbinnet_state_dict": mapbinnet.state_dict()}, "/content/drive/My Drive/Artificial intelligence/bigmap.th")
        print(f"{epoch+1} saved")


Epoch [141/400] Loss = 2.224152088165283 TrainAccuracy = 0.21875 ValAccuracy = 0.2316
Epoch [142/400] Loss = 2.254455089569092 TrainAccuracy = 0.1875 ValAccuracy = 0.2192
Epoch [143/400] Loss = 2.196929693222046 TrainAccuracy = 0.203125 ValAccuracy = 0.2235
Epoch [144/400] Loss = 2.193085193634033 TrainAccuracy = 0.25 ValAccuracy = 0.2196
Epoch [145/400] Loss = 2.2387006282806396 TrainAccuracy = 0.21875 ValAccuracy = 0.2198
Epoch [146/400] Loss = 2.2287356853485107 TrainAccuracy = 0.1875 ValAccuracy = 0.2208
Epoch [147/400] Loss = 2.2239320278167725 TrainAccuracy = 0.21875 ValAccuracy = 0.2189
Epoch [148/400] Loss = 2.2119691371917725 TrainAccuracy = 0.203125 ValAccuracy = 0.2231
Epoch [149/400] Loss = 2.1483683586120605 TrainAccuracy = 0.328125 ValAccuracy = 0.2181
Epoch [150/400] Loss = 2.262892484664917 TrainAccuracy = 0.125 ValAccuracy = 0.2166
150 saved
Epoch [151/400] Loss = 2.2370314598083496 TrainAccuracy = 0.21875 ValAccuracy = 0.2187
Epoch [152/400] Loss = 2.2068018913269043 

In [25]:
mapbinnet = BinResNet(BinBlock, [3, 3, 3], 10, bin_fn="map").to(device)
x = torch.load("/content/drive/My Drive/Artificial intelligence/bigmap.th")
net = ResNet(BinBlock, [3, 3, 3], 10).to(device)
net.load_state_dict(x["net_state_dict"])
y = torch.rand(3, 3, 224, 224).to(device)
z = mapbinnet(y)
mapbinnet_optim = torch.optim.SGD(mapbinnet.parameters(), lr = 0.1, momentum=0.9)
net.load_state_dict(x["net_state_dict"])
mapbinnet.load_state_dict(x["mapbinnet_state_dict"])
scheduler = torch.optim.lr_scheduler.StepLR(mapbinnet_optim, step_size=5, gamma=0.1)

In [26]:
train_losses = []

In [27]:
# Fine tune with the noisy supervision
num_train_epochs = 120
for epoch in range(num_train_epochs):
    for i, (images, labels) in enumerate(train_loader):
        images = images.to(device)
        labels = labels.to(device)

        out = mapbinnet(images)
        net_optim.zero_grad()
        mapbinnet_optim.zero_grad()
        # map_optim.zero_grad()

        # Forward pass
        noisy_loss, mapbinnet, grad_aux_W = learn(net, mapbinnet)

        ce_loss = nn.CrossEntropyLoss()(out, labels)
        ce_loss.backward()
        nn.utils.clip_grad_value_(mapbinnet.parameters(), 1)

        for (name_p, p), (name_q, q) in zip(net.named_parameters(), mapbinnet.named_parameters()):
            if name_p==name_q:
                if q.grad is not None:
                    p.grad = q.grad
                    if all([i in name_p for i in bin]):
                        p.grad.data += grad_aux_W[name_p]

        net_optim.step()
        mapbinnet_optim.step()
        # map_optim.step()
        scheduler.step()
        train_losses.append(ce_loss.to("cpu").item())
        trainaccuracy = torch.sum(labels==torch.argmax(out, dim=-1)).to("cpu").item()/len(images)

    with torch.no_grad():
            x = torch.tensor([0.0], device = device)
            for images, labels in val_loader:
                images = images.to(device)
                labels = labels.to(device)
                # print(labels.shape)
                out = mapbinnet(images)
                # print(torch.argmax(out, dim = -1).shape)
                x += torch.sum(labels == torch.argmax(out, dim = -1))
            valaccuracy = x.to("cpu").item()/10000

    print(f"Epoch [{epoch+1}/{num_train_epochs}] Loss = {train_losses[-1]} TrainAccuracy = {trainaccuracy} ValAccuracy = {valaccuracy}")
    if (epoch+1) % 10 == 0:
        torch.save({"epoch": epoch+1, "net_state_dict": net.state_dict(), "mapbinnet_state_dict": mapbinnet.state_dict()}, "/content/drive/My Drive/Artificial intelligence/mapbigmodel.th")
        print(f"{epoch+1} saved")

Epoch [1/120] Loss = 2.295302152633667 TrainAccuracy = 0.078125 ValAccuracy = 0.2262
Epoch [2/120] Loss = 2.2043542861938477 TrainAccuracy = 0.21875 ValAccuracy = 0.2254
Epoch [3/120] Loss = 2.1670751571655273 TrainAccuracy = 0.28125 ValAccuracy = 0.2282
Epoch [4/120] Loss = 2.2343382835388184 TrainAccuracy = 0.21875 ValAccuracy = 0.2247
Epoch [5/120] Loss = 2.2555980682373047 TrainAccuracy = 0.1875 ValAccuracy = 0.2231
Epoch [6/120] Loss = 2.2341737747192383 TrainAccuracy = 0.203125 ValAccuracy = 0.2261
Epoch [7/120] Loss = 2.1567904949188232 TrainAccuracy = 0.3125 ValAccuracy = 0.2253
Epoch [8/120] Loss = 2.188002586364746 TrainAccuracy = 0.25 ValAccuracy = 0.2238
Epoch [9/120] Loss = 2.2232489585876465 TrainAccuracy = 0.21875 ValAccuracy = 0.2281
Epoch [10/120] Loss = 2.2038397789001465 TrainAccuracy = 0.265625 ValAccuracy = 0.2258
10 saved
Epoch [11/120] Loss = 2.1856446266174316 TrainAccuracy = 0.265625 ValAccuracy = 0.2225
Epoch [12/120] Loss = 2.2313175201416016 TrainAccuracy = 

KeyboardInterrupt: ignored