In [2]:
import os
import random

import jittor as jt
import jittor.transform as trans
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pygmtools as pygm
from jittor import Module
from jittor import nn
from jittor.dataset import DataLoader
from jittor.dataset.cifar import CIFAR10
from jittor.optim import Optimizer

In [3]:
# Use the GPU when Jittor detects CUDA and fall back to the CPU otherwise,
# instead of unconditionally requesting CUDA.
jt.flags.use_cuda = 1 if jt.has_cuda else 0

[38;5;2m[i 0402 19:32:27.163587 00 cuda_flags.cc:49] CUDA enabled.[m


In [4]:
'''Resnet definition'''
'''Todo'''
class Resnet(Module):
    """Five-stage convolutional network that scores a 4x4 patch assignment.

    Despite its name, no residual (skip) connections are defined in this class.
    Each stage is Conv(3x3, stride 1, padding 1) -> BatchNorm -> ReLU; the first
    four stages are followed by a 2x max-pool.  The flattened features go through
    Linear -> BatchNorm -> ReLU -> Linear(16), are reshaped to (batch, 4, 4) and
    passed to pygmtools' Sinkhorn solver.

    The first Linear layer expects 128*4*4 features, i.e. a 64x64 spatial input
    after the four 2x poolings.

    Attribute names (conv1..conv5, bn1..bn6, l1, fc) are kept exactly as before
    so that previously saved checkpoints keep matching.
    """

    def __init__(self):
        super(Resnet, self).__init__()

        # Parameter-free layers reused by every stage.
        self.relu = nn.Relu()
        self.maxpool = nn.MaxPool2d(2)

        # Channel widths of the five conv stages: 3 -> 8 -> 16 -> 32 -> 64 -> 128.
        widths = (3, 8, 16, 32, 64, 128)
        for stage, (c_in, c_out) in enumerate(zip(widths[:-1], widths[1:]), start=1):
            setattr(self, f"conv{stage}", nn.Conv(c_in, c_out, 3, 1, 1))
            setattr(self, f"bn{stage}", nn.BatchNorm(c_out))

        # Classifier head producing 16 scores (later viewed as a 4x4 matrix).
        self.flatten = nn.Flatten()
        self.l1 = nn.Linear(128 * 4 * 4, 512)
        self.bn6 = nn.BatchNorm(512)
        self.fc = nn.Linear(512, 16)

    def execute(self, x):
        """Run the forward pass and return the Sinkhorn output of shape (batch, 4, 4)."""
        # Stages 1-4: conv -> bn -> relu -> 2x max-pool.
        for stage in range(1, 5):
            conv = getattr(self, f"conv{stage}")
            norm = getattr(self, f"bn{stage}")
            x = self.maxpool(self.relu(norm(conv(x))))

        # Stage 5 has no pooling.
        x = self.relu(self.bn5(self.conv5(x)))

        x = self.flatten(x)
        x = self.relu(self.bn6(self.l1(x)))
        scores = jt.reshape(self.fc(x), (-1, 4, 4))

        # Sinkhorn is applied in place of a plain softmax over the 4x4 scores.
        return pygm.linear_solvers.sinkhorn(scores)

In [5]:
def train_image_load(train_data, group_size=4):
    """Yield training batches of images cut into four 32x32 patches.

    Every image is permuted with ``permute(2, 1, 0)`` (channel axis first, the
    two spatial axes swapped) and sliced into four 32x32 quadrants.  The
    quadrants of ``group_size`` consecutive loader batches are stacked into one
    float32 ``jt.Var`` of shape (num_images, 4, C, 32, 32).

    Args:
        train_data: iterable of ``(images, labels)`` pairs; labels are ignored.
            Assumes each image is channel-last 64x64x3 -- TODO confirm.
        group_size: number of loader batches merged into one yielded batch
            (default 4, the value that used to be hard-coded).

    Yields:
        jt.Var: float32 patch tensor for each group of loader batches.  Loader
        batches left over when the total is not a multiple of ``group_size``
        are yielded as a final, smaller batch instead of being discarded.

    Raises:
        ValueError: if ``group_size`` is smaller than 1.
    """
    if group_size < 1:
        raise ValueError("group_size must be >= 1")

    pending = []        # patch quadruples collected for the current group
    batches_seen = 0    # number of loader batches consumed so far
    for img, _labels in train_data:
        batches_seen += 1
        for i in range(len(img)):
            # Permute once per image instead of once per quadrant.
            chw = img[i].permute(2, 1, 0)
            pending.append([
                chw[:, 0:32, 0:32],
                chw[:, 32:32 * 2, 0:32],
                chw[:, 0:32, 32:32 * 2],
                chw[:, 32:32 * 2, 32:32 * 2],
            ])
        if batches_seen % group_size == 0:
            yield jt.Var(np.array(pending)).float32()
            pending = []

    # Flush the incomplete last group so no training image is silently dropped.
    if pending:
        yield jt.Var(np.array(pending)).float32()
    


def target_generation(images):
    """Shuffle the four patches of every image and build the matching targets.

    For each sample a random order of the 4 patches is drawn with
    ``np.random.permutation`` (one draw per sample, in sample order).  The
    patches are stacked in that order, the patch and channel axes are swapped,
    and the result is flattened into a (3, 64, 64) array.  Note that this is a
    plain row-major reshape of (3, 4, 32, 32), not a 2x2 spatial tiling.

    Args:
        images: batch of patches, indexed as ``images[sample][patch]``.
            Assumes shape (N, 4, 3, 32, 32) -- TODO confirm against the loader.

    Returns:
        tuple: ``(rearranged_images, targets)`` as float32 ``jt.Var`` objects of
        shape (N, 3, 64, 64) and (N, 4, 4).  ``targets[n, j, k] == 1`` means
        position ``j`` of the shuffled sample holds original patch ``k``.
    """
    # Copy the batch to host memory once instead of converting patch by patch.
    patches = images.numpy() if hasattr(images, "numpy") else np.asarray(images)
    n_samples = len(patches)
    if n_samples == 0:
        empty = np.array([])
        return jt.Var(empty).float32(), jt.Var(empty).float32()

    # One permutation per sample, drawn in sample order.
    orders = np.stack([np.random.permutation(4) for _ in range(n_samples)])
    sample_idx = np.arange(n_samples)[:, None]

    # shuffled[n, j] = patches[n, orders[n, j]]  -> (N, 4, 3, 32, 32)
    shuffled = patches[sample_idx, orders]
    # Swap patch/channel axes, then flatten patches into the 64x64 plane.
    rearranged_images = shuffled.transpose(0, 2, 1, 3, 4).reshape(n_samples, 3, 64, 64)

    # One-hot permutation matrices: row j has its 1 in column orders[n, j].
    targets = np.zeros((n_samples, 4, 4))
    targets[sample_idx, np.arange(4), orders] = 1

    return jt.Var(rearranged_images).float32(), jt.Var(targets).float32()

        
        
def train(net,optimizer,train_data_loader,epoch,repeats_per_batch=4,log_every=500):
    """Train ``net`` for one pass over ``train_data_loader``.

    Each batch from ``train_image_load`` is shuffled by ``target_generation``
    and then used for ``repeats_per_batch`` consecutive optimisation steps with
    pygmtools' permutation loss.

    Args:
        net: model to train; switched to training mode first.
        optimizer: optimizer whose ``step(loss)`` performs the update.
        train_data_loader: loader passed through to ``train_image_load``.
        epoch: epoch number, used only in the progress message.
        repeats_per_batch: optimisation steps per generated batch (default 4,
            the value that used to be hard-coded).
        log_every: print the current loss every this many steps (default 500).

    Returns:
        The mean loss over all optimisation steps, or ``float('nan')`` when the
        loader produced no batch (previously a ZeroDivisionError).
    """
    net.train()
    train_step=0
    total_loss=0
    for image in train_image_load(train_data_loader):
        inputs,targets=target_generation(image)     # shuffled inputs and their permutation targets
        for _ in range(repeats_per_batch):
            outputs=net(inputs)
            loss=pygm.utils.permutation_loss(outputs,targets)
            optimizer.step(loss)
            train_step+=1
            total_loss+=loss
            if train_step%log_every==0:
                print(f'epoch:{epoch},Step:{train_step},Loss:{loss}')

    # Nothing was trained: report "no value" instead of dividing by zero.
    if train_step==0:
        return float('nan')
    return total_loss/train_step


def test_image_load(test_data):
    """Yield one float32 patch tensor per loader batch of ``test_data``.

    Every image is permuted with ``permute(2, 1, 0)`` and cut into four 32x32
    quadrants; the quadruples of one loader batch are stacked into a
    ``jt.Var`` and yielded.  Labels coming from the loader are ignored.
    """
    for img, _label in test_data:
        quadruples = []
        for idx in range(len(img)):
            chw = img[idx].permute(2, 1, 0)
            quadruples.append([
                chw[:, 0:32, 0:32],     # first half of axis 1, first half of axis 2
                chw[:, 32:64, 0:32],    # second half of axis 1, first half of axis 2
                chw[:, 0:32, 32:64],    # first half of axis 1, second half of axis 2
                chw[:, 32:64, 32:64],   # second half of axis 1, second half of axis 2
            ])
        stacked = np.array(quadruples)
        yield jt.Var(stacked).float32()


def test_target_generation(images):
    '''Randomly shuffle permutation of image,Generate target'''
    rearranged_images=[]
    targets=[]
    for i in range(len(images)):
        permute=np.random.permutation(4)[:4]
        rearranged_img=[]
        target=np.zeros((4,4))
        for j in range(len(images[i])):
            rearranged_img.append(images[i][permute[j]])
            target[j][permute[j]]=1
        rearranged_img=jt.Var(rearranged_img).permute(1,0,2,3)
        rearranged_img=np.reshape(rearranged_img,(3,64,64))
        rearranged_images.append(rearranged_img)
        targets.append(target)
    rearranged_images,targets=np.array(rearranged_images),np.array(targets)
    rearranged_images,targets=jt.Var(rearranged_images).float32(),jt.Var(targets).float32()
    return rearranged_images,targets

def eval(outputs,target_i):
    """Count the patch positions whose predicted index matches the target.

    For every row of every (4, 4) matrix the arg-max column of ``outputs`` is
    compared with the arg-max column of ``target_i``.

    Args:
        outputs: predicted assignment matrices, indexed ``[sample][row][col]``.
        target_i: target permutation matrices with the same layout.

    Returns:
        int: number of rows, summed over the batch, where both arg-max columns
        agree.

    Note: the name shadows Python's built-in ``eval``; it is kept because
    other cells call it under this name.
    """
    if len(outputs)==0:
        return 0

    # jt.argmax returns (indices, values); one batched call over the last axis
    # replaces the former per-element comparisons of scalar Vars.
    pred_cols=jt.argmax(outputs,2)[0]
    real_cols=jt.argmax(target_i,2)[0]
    matches=(pred_cols==real_cols)

    # Bring the boolean mask to host memory once and count there.
    matches=matches.numpy() if hasattr(matches,"numpy") else np.asarray(matches)
    return int(matches.sum())

def test(net,optimizer,test_data_loader,epoch):
    """Evaluate ``net`` on ``test_data_loader`` and return the accuracy in percent.

    Accuracy is the share of patch positions whose predicted index matches the
    target, as counted by ``eval``.  The number of positions per batch is taken
    from the target tensor rather than assumed to be 64, so other batch sizes
    and a short last batch are weighted correctly.

    Args:
        net: model to evaluate; switched to evaluation mode first.
        optimizer: unused; kept so existing callers do not need to change.
        test_data_loader: loader passed through to ``test_image_load``.
        epoch: epoch number, used only in the progress messages.

    Returns:
        float: overall accuracy in percent, or ``float('nan')`` when the loader
        produced no samples (previously a ZeroDivisionError).
    """
    test_step=0
    correct_total=0     # correctly placed patches over the whole test set
    slots_total=0       # patch positions seen over the whole test set
    net.eval()
    for image in test_image_load(test_data_loader):
        inputs,targets=test_target_generation(image)
        if len(targets)==0:
            continue
        outputs=net(inputs)
        # samples * positions per sample
        slots=int(targets.shape[0])*int(targets.shape[1])
        acc=eval(outputs,targets)
        correct_total+=acc
        slots_total+=slots
        test_step+=1
        if test_step%100==0:
            print(f'epoch:{epoch},Step:{test_step},Accuracy:{acc/slots*100}%')

    overall=correct_total/slots_total*100 if slots_total else float('nan')
    print(f'\n epoch:{epoch},Accuracy:{overall}%\n')
    return overall

In [6]:
'''The target and the full set of images have completed'''
# --- Backend -----------------------------------------------------------------
# pygmtools must run on the Jittor backend before the network is called.
pygm.set_backend('jittor')

# --- Model, optimiser and learning-rate schedule ------------------------------
net = Resnet()
learning_rate = 1e-5
optimizer = nn.SGD(net.parameters(), lr=learning_rate, momentum=0.9)
# Multiply the learning rate by 0.1 every 100 scheduler steps.
scheduler = jt.lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.1)

# --- Data: CIFAR-10 randomly cropped and resized to 64x64 ---------------------
# NOTE(review): the test split also uses RandomResizedCrop, so evaluation
# inputs differ between runs -- confirm this is intended.
train_data = CIFAR10(train=True, transform=trans.RandomResizedCrop((64, 64)))
train_data_loader = DataLoader(train_data, batch_size=16)
test_data = CIFAR10(train=False, transform=trans.RandomResizedCrop((64, 64)))
test_data_loader = DataLoader(test_data, batch_size=16)

# --- Run bookkeeping ----------------------------------------------------------
epochs = 100000
train_loss = []     # value returned by train() for each epoch
test_acc = []       # value returned by test() for each epoch
epoch = 0           # change here after the kernel break.

# Checkpoint locations: network weights and the epoch counter.
path_pkl = "/root/perm_jittor.pkl"
path_p = "/root/epoch_jittor.p"

Files already downloaded and verified
Files already downloaded and verified


In [11]:
# Write the initial checkpoint (epoch counter + network weights) only when none
# exists yet.  Guarded so that re-running the notebook from the top never
# overwrites the progress of an earlier run; delete both files to start over.
if not (os.path.exists(path_p) and os.path.exists(path_pkl)):
    epoch_dict={"epoch":epoch}
    jt.save(epoch_dict,path_p)
    net.save(path_pkl)

In [7]:
# Resume once, before training starts: restore the epoch counter and the
# network weights if a checkpoint is present, otherwise start from the current
# in-memory state.  (The files written at the end of each epoch always match
# the in-memory state, so reloading them inside the loop is unnecessary.)
# NOTE(review): only the epoch counter and the network weights are
# checkpointed; optimizer and scheduler state are not restored on resume.
if os.path.exists(path_p) and os.path.exists(path_pkl):
    checkpoint=jt.load(path_p)
    if checkpoint['epoch'] > 0:
        epoch=checkpoint['epoch']
    net.load(path_pkl)

while epoch < epochs:
    print("epoch: ",epoch+1)
    train_loss.append(train(net,optimizer,train_data_loader,epoch+1))
    test_acc.append(test(net,optimizer,test_data_loader,epoch+1))
    scheduler.step()
    epoch+=1
    # Persist progress after every completed epoch so an interrupted run can resume.
    epoch_dict={"epoch":epoch}
    jt.save(epoch_dict,path_p)
    net.save(path_pkl)

epoch:  85
epoch:85,Step:500,Loss:0.8513048887252808
epoch:85,Step:1000,Loss:0.9510497450828552
epoch:85,Step:1500,Loss:0.9607932567596436
epoch:85,Step:2000,Loss:0.914334774017334
epoch:85,Step:2500,Loss:0.9032962322235107
epoch:85,Step:3000,Loss:1.005419373512268
epoch:85,Step:100,Accuracy:71.875%
epoch:85,Step:200,Accuracy:90.625%
epoch:85,Step:300,Accuracy:92.1875%
epoch:85,Step:400,Accuracy:79.6875%
epoch:85,Step:500,Accuracy:89.0625%
epoch:85,Step:600,Accuracy:82.8125%

 epoch:85,Accuracy:80.595%

epoch:  86
epoch:86,Step:500,Loss:0.9541069269180298
epoch:86,Step:1000,Loss:0.9439809322357178
epoch:86,Step:1500,Loss:0.9454079866409302
epoch:86,Step:2000,Loss:0.9303105473518372
epoch:86,Step:2500,Loss:0.8985252976417542
epoch:86,Step:3000,Loss:0.8568313717842102
epoch:86,Step:100,Accuracy:81.25%
epoch:86,Step:200,Accuracy:82.8125%
epoch:86,Step:300,Accuracy:85.9375%
epoch:86,Step:400,Accuracy:76.5625%
epoch:86,Step:500,Accuracy:85.9375%
epoch:86,Step:600,Accuracy:84.375%

 epoch:86

KeyboardInterrupt: 

In [None]:
# Plot the per-epoch history: test accuracy first, then training loss,
# each on its own figure.
for history, colour, series_name, y_label in (
    (test_acc, 'r', "test_acc", "acc"),
    (train_loss, 'g', "train_loss", "loss"),
):
    fig, ax = plt.subplots()
    ax.plot(history, colour, label=series_name)
    ax.set_xlabel("Epochs")
    ax.set_ylabel(y_label)
    ax.legend()
    plt.show()