In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F 
import torch.optim as optim 
import torch.utils.data as Data  # to make Loader
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt 
import os
import numpy as np 
import time
import csv
import pygmtools as pygm

In [17]:
class Resnet(nn.Module):
    """Plain CNN that scores which patch belongs in which slot of a 2x2 jigsaw.

    Despite the name there are no residual connections: the network is five
    conv/batch-norm/ReLU stages (the first four each followed by a 2x2 max
    pool), a 512-unit hidden layer and a 16-way output reshaped to 4x4.

    The first linear layer expects ``128 * 4 * 4`` features, so after the four
    2x poolings the input must be 3 x 64 x 64.

    Attribute names and their registration order are part of the checkpoint
    format (``state_dict`` keys / optimizer parameter order); keep them stable.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 8, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(8)
        # Stateless layers, shared by every stage in forward().
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(2)
        self.conv2 = nn.Conv2d(8, 16, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(16)
        self.conv3 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
        self.bn3 = nn.BatchNorm2d(32)
        self.conv4 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.bn4 = nn.BatchNorm2d(64)
        # Final conv stage: not followed by pooling.
        self.conv5 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        self.bn5 = nn.BatchNorm2d(128)
        self.flatten = nn.Flatten()
        self.l1 = nn.Linear(128 * 4 * 4, 512)
        self.bn6 = nn.BatchNorm1d(512)
        self.fc = nn.Linear(512, 16)

    def forward(self, x):
        """Map a batch of images to a batch of 4x4 assignment matrices.

        The raw 16 scores per image are reshaped to 4x4 and passed through
        pygmtools' Sinkhorn linear solver before being returned.
        """
        pooled_stages = (
            (self.conv1, self.bn1),
            (self.conv2, self.bn2),
            (self.conv3, self.bn3),
            (self.conv4, self.bn4),
        )
        for conv, norm in pooled_stages:
            x = self.maxpool(self.relu(norm(conv(x))))

        x = self.relu(self.bn5(self.conv5(x)))
        hidden = self.relu(self.bn6(self.l1(self.flatten(x))))
        scores = self.fc(hidden).reshape(-1, 4, 4)
        return pygm.linear_solvers.sinkhorn(scores)






def train_image_load(train_data, group_size=4, drop_last=True):
    """Yield quadrant tensors built from several consecutive loader batches.

    Every image is cut into four 32x32 quadrants, stored in the order
    top-left, bottom-left, top-right, bottom-right. Quadrants from
    ``group_size`` consecutive loader batches are concatenated and yielded as
    one tensor.

    Args:
        train_data: iterable of ``(images, labels)`` pairs. ``images`` is a
            tensor indexed as (batch, channel, row, col) and is expected to be
            64x64 spatially; ``labels`` are ignored.
        group_size: number of loader batches merged into each yielded tensor.
            The default of 4 reproduces the previously hard-coded behaviour.
        drop_last: if True (default, the previous behaviour) a trailing group
            with fewer than ``group_size`` batches is discarded; if False it
            is yielded as a smaller final tensor.

    Yields:
        Tensor of shape (n_images, 4, channels, 32, 32).

    Raises:
        ValueError: if ``group_size`` is smaller than 1.
    """
    if group_size < 1:
        raise ValueError(f"group_size must be >= 1, got {group_size}")

    pending = []  # quadrant tensors of the loader batches collected so far
    for images, _labels in train_data:
        # One stack call slices the whole batch at once; dim=1 becomes the
        # quadrant axis.
        quadrants = torch.stack(
            (
                images[:, :, 0:32, 0:32],
                images[:, :, 32:64, 0:32],
                images[:, :, 0:32, 32:64],
                images[:, :, 32:64, 32:64],
            ),
            dim=1,
        )
        pending.append(quadrants)
        if len(pending) == group_size:
            yield torch.cat(pending, dim=0)
            pending = []

    if pending and not drop_last:
        yield torch.cat(pending, dim=0)
    


def target_generation(images):
    '''Randomly shuffle permutation of image,Generate target'''
    images=images.numpy()
    rearranged_images=[]
    targets=[]
    for i in range(len(images)):
        permute=np.random.permutation(4)[:4]
        rearranged_img=[]
        target=np.zeros((4,4))
        for j in range(len(images[i])):
            rearranged_img.append(images[i][permute[j]])
            target[j][permute[j]]=1
        #rearranged_img=torch.tensor(rearranged_img)
        rearranged_img=np.reshape(rearranged_img,(3,64,64))
        rearranged_images.append(rearranged_img)
        targets.append(target)
    rearranged_images,targets=np.array(rearranged_images),np.array(targets)
    rearranged_images,targets=torch.tensor(rearranged_images),torch.tensor(targets)
    return rearranged_images,targets

        
        
def train(net,optimizer,train_data_loader,epoch,device):
    net.train()
    train_step=0
    total_loss=0
    for image in train_image_load(train_data_loader):
        inputs,targets=target_generation(image)     #(64,3,64,64) vs (64,4,4)
        for i in range(4):
            inputs,targets=inputs.float().to(device),targets.float().to(device)
            outputs=net(inputs)
            optimizer.zero_grad()
            outputs,targets=outputs.float(),targets.float()
            loss=pygm.utils.permutation_loss(outputs,targets)
            loss.backward()
            optimizer.step()
            train_step+=1
            total_loss+=loss
            if train_step%500==0:
                print(f'epoch:{epoch},Step:{train_step},Loss:{loss}')
                #format_text=f"epoch:{epoch},Step:{train_step},Loss:{loss}\n"
                #file.write(format_text)
    return total_loss/train_step


def test_image_load(test_data):
    batch_size=0
    imgs=[]
    for data in test_data:
        img,target=data
        img=img.numpy()
        batch_size+=1
        for i in range(len(img)):
            image=[]
            '''Containing all 4 parts.'''
            image.append(img[i][:,0:32,0:32])
            image.append(img[i][:,32:32*2,0:32])
            image.append(img[i][:,0:32,32:32*2])
            image.append(img[i][:,32:32*2,32:32*2])
            imgs.append(image)
        imgs=np.array(imgs)
        imgs=torch.tensor(imgs)
        yield imgs
        imgs=[]


def test_target_generation(images):
    '''Randomly shuffle permutation of image,Generate target'''
    images=images.numpy()
    rearranged_images=[]
    targets=[]
    for i in range(len(images)):
        permute=np.random.permutation(4)[:4]
        rearranged_img=[]
        target=np.zeros((4,4))
        for j in range(len(images[i])):
            rearranged_img.append(images[i][permute[j]])
            target[j][permute[j]]=1
        #rearranged_img=torch.tensor(rearranged_img)
        rearranged_img=np.reshape(rearranged_img,(3,64,64))
        rearranged_images.append(rearranged_img)
        targets.append(target)
    rearranged_images,targets=np.array(rearranged_images),np.array(targets)
    rearranged_images,targets=torch.tensor(rearranged_images),torch.tensor(targets)
    return rearranged_images,targets

def eval(outputs,target_i):
    """Count how many slots were assigned the correct patch.

    For every sample and every row (slot) the arg-max column of ``outputs`` is
    compared with the arg-max column of ``target_i``.

    Note: this function shadows the built-in ``eval``; the name is kept
    because callers in this notebook use it.

    Args:
        outputs: tensor of shape (n_samples, n_slots, n_patches) with scores.
        target_i: tensor of the same shape holding the one-hot targets.

    Returns:
        Number of matching rows over the whole batch, as a Python int.
    """
    if len(outputs)==0:
        return 0
    # One vectorised comparison instead of a Python loop over elements, which
    # on CUDA forced a device sync for every single slot.
    predicted=torch.argmax(outputs,dim=-1)
    expected=torch.argmax(target_i[:len(outputs)],dim=-1)
    return int((predicted==expected).sum().item())

def test(net,optimizer,test_data_loader,epoch,device):
    """Evaluate ``net`` on the test loader and return the mean accuracy in percent.

    Accuracy is the fraction of slots whose arg-max prediction matches the
    target, computed per batch and then averaged over batches.

    Args:
        net: model to evaluate; switched to eval mode.
        optimizer: unused; kept so existing call sites keep working.
        test_data_loader: loader handed to ``test_image_load``.
        epoch: epoch number, used only in the progress messages.
        device: device the inputs and targets are moved to.

    Returns:
        Mean per-batch accuracy in percent (``nan`` if the loader was empty).
    """
    test_step=0
    total_acc=0
    net.eval()
    # Evaluation needs no gradients; skipping graph construction saves memory
    # and time.
    with torch.no_grad():
        for image in test_image_load(test_data_loader):
            inputs,targets=test_target_generation(image)
            inputs,targets=inputs.float().to(device),targets.float().to(device)
            outputs=net(inputs)
            # One prediction per slot per image. Derived from the batch itself
            # instead of assuming 16 images x 4 slots = 64.
            n_predictions=targets.shape[0]*targets.shape[1]
            batch_acc=eval(outputs,targets)/n_predictions
            total_acc+=batch_acc
            test_step+=1
            if test_step%100==0:
                print(f'epoch:{epoch},Step:{test_step},Accuracy:{batch_acc*100}%')

    mean_acc=total_acc/test_step*100 if test_step else float("nan")
    print(f'\n epoch:{epoch},Accuracy:{mean_acc}%\n')
    return mean_acc

In [14]:
# Run configuration: model, optimiser, data loaders and the bookkeeping
# variables used by the training loop in the next cell.

# Use the GPU when one is visible; otherwise fall back to the CPU so the
# notebook still runs end to end.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
pygm.set_backend('pytorch')
net=Resnet().to(device)
learning_rate=1e-5
optimizer=optim.SGD(net.parameters(),lr=learning_rate,momentum=0.9)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.1)
# 1. Get the train data
# NOTE(review): the same random-crop transform is also applied to the test
# set below, so evaluation images are randomly augmented - confirm intended.
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),transforms.RandomResizedCrop((32*2,32*2),antialias=True)])
train_data = torchvision.datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transform
)
train_data_loader = Data.DataLoader(
    train_data,
    batch_size=16,
    shuffle=True,
    # num_workers=2 # ready to be commented(windows)
)
test_data = torchvision.datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=transform,
)
test_data_loader = Data.DataLoader(
    test_data,
    batch_size=16,
    shuffle=False,
    # num_workers=2
)
epochs=int(1e5)
epoch=0
train_epoch_loss=0.0
path="/root/pytorch_model.pt"
train_loss=[]
test_acc=[]
# Bootstrap an "epoch 0" checkpoint on the very first run so the training
# cell, which starts by loading `path`, works on a fresh machine. An existing
# checkpoint is never overwritten here.
if not os.path.exists(path):
    torch.save({'epoch': 0,'model_state_dict': net.state_dict(),'optimizer_state_dict': optimizer.state_dict(),'loss':0.0},path)

Files already downloaded and verified
Files already downloaded and verified


In [19]:
# Resume once, before the loop. Inside the loop the in-memory state is already
# what was last written to disk, so re-reading the file every epoch is
# redundant.
if os.path.exists(path):
    # map_location lets a checkpoint written on GPU be resumed on CPU and
    # vice versa.
    checkpoint=torch.load(path,map_location=device)
    epoch=checkpoint['epoch']
    if epoch > 0:
        net.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        # Older checkpoints were saved without scheduler state; for those the
        # scheduler keeps its freshly constructed state, as before.
        if 'scheduler_state_dict' in checkpoint:
            scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

while epoch < epochs:
    print("epoch: ",epoch+1)
    train_loss.append(train(net,optimizer,train_data_loader,epoch+1,device))
    test_acc.append(test(net,optimizer,test_data_loader,epoch+1,device))
    scheduler.step()
    epoch+=1
    # Saved after every epoch so an interrupted run loses at most one epoch.
    # The scheduler state is included so the LR schedule survives a restart.
    torch.save({'epoch': epoch,'model_state_dict': net.state_dict(),'optimizer_state_dict': optimizer.state_dict(),'scheduler_state_dict': scheduler.state_dict()},path)

epoch:  38
epoch:38,Step:500,Loss:1.8922981023788452
epoch:38,Step:1000,Loss:1.7212601900100708
epoch:38,Step:1500,Loss:1.9927786588668823
epoch:38,Step:2000,Loss:1.8845889568328857
epoch:38,Step:2500,Loss:1.898523211479187
epoch:38,Step:3000,Loss:1.8854297399520874
epoch:38,Step:100,Accuracy:40.625%
epoch:38,Step:200,Accuracy:60.9375%
epoch:38,Step:300,Accuracy:35.9375%
epoch:38,Step:400,Accuracy:51.5625%
epoch:38,Step:500,Accuracy:32.8125%
epoch:38,Step:600,Accuracy:48.4375%

 epoch:38,Accuracy:50.2675%

epoch:  39
epoch:39,Step:500,Loss:1.8303301334381104
epoch:39,Step:1000,Loss:1.8696353435516357
epoch:39,Step:1500,Loss:1.8749390840530396
epoch:39,Step:2000,Loss:1.8595794439315796
epoch:39,Step:2500,Loss:1.7932077646255493
epoch:39,Step:3000,Loss:1.8682003021240234
epoch:39,Step:100,Accuracy:56.25%
epoch:39,Step:200,Accuracy:62.5%
epoch:39,Step:300,Accuracy:56.25%
epoch:39,Step:400,Accuracy:59.375%
epoch:39,Step:500,Accuracy:53.125%
epoch:39,Step:600,Accuracy:50.0%

 epoch:39,Accur

: 

In [None]:
# Test accuracy (percent) recorded after every epoch.
plt.plot(test_acc,'r',label="test_acc")
plt.xlabel("Epochs")
plt.ylabel("acc")
plt.legend()
plt.show()

# Mean training loss per epoch. float() accepts plain numbers as well as
# 0-dim tensors (including CUDA tensors attached to a graph), which matplotlib
# cannot convert on its own.
plt.plot([float(loss) for loss in train_loss],'g',label="train_loss")
plt.xlabel("Epochs")
plt.ylabel("loss")
# legend must be called; the bare attribute reference drew nothing.
plt.legend()
plt.show()