In [None]:
import torch
from torch.utils.data import DataLoader as dl
from torch import nn

from torchvision import datasets as dt
from torchvision.transforms import ToTensor,Lambda,Compose
import matplotlib.pyplot as plt

# FashionMNIST train / test splits, stored under ./fa (downloaded if missing)
# and converted to tensors by ToTensor().
tr = dt.FashionMNIST(root='fa', download=True, train=True, transform=ToTensor())
te = dt.FashionMNIST(root='fa', download=True, train=False, transform=ToTensor())

# Reshuffle the training samples every epoch so the optimiser does not see the
# exact same sequence of mini-batches on each pass. The evaluation loader
# keeps the dataset order.
trdl = dl(tr, batch_size=128, shuffle=True)
tedl = dl(te, batch_size=128)

class ypnet(nn.Module):
    """Fully connected classifier: 784 -> 392 -> 156 -> 78 -> 10.

    ``forward`` returns raw class scores (logits). No softmax is applied here
    because ``nn.CrossEntropyLoss`` applies log-softmax internally; stacking an
    explicit ``nn.Softmax`` in front of it normalises twice and flattens the
    gradients. ``argmax`` over the logits yields the same predicted class as
    ``argmax`` over the probabilities would.
    """
    def __init__(self):
        super(ypnet,self).__init__()
        self.net = nn.Sequential(
            nn.Flatten(),          # (N, 1, 28, 28) -> (N, 784)
            nn.Linear(784,392),
            nn.ReLU(),
            nn.Linear(392,156),
            nn.ReLU(),
            nn.Linear(156,78),
            nn.ReLU(),
            nn.Linear(78,10),      # logits, one per class
        )
    def forward(self,x):
        """Return the (N, 10) logits for a batch ``x``."""
        return self.net(x)

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Model, objective and optimiser for the fully connected network.
yp = ypnet()
yp = yp.to(device)
loss = nn.CrossEntropyLoss()
opt = torch.optim.Adam(params=yp.parameters())

def train(dl,model,lossf,opt):
    """Run one optimisation pass over every batch yielded by ``dl``.

    Args:
        dl: iterable of ``(inputs, targets)`` batches.
        model: the ``nn.Module`` being optimised.
        lossf: callable ``lossf(predictions, targets)`` returning a scalar loss.
        opt: optimiser that owns ``model``'s parameters.

    Batches are moved to the module-level ``device`` before the forward pass.
    """
    # test() leaves the model in eval mode; switch back so layers such as
    # Dropout / BatchNorm use their training behaviour on every epoch, not
    # only the first one.
    model.train()
    for x,y in dl:
        x,y = x.to(device),y.to(device)
        pre = model(x)
        loss = lossf(pre,y)

        opt.zero_grad()
        loss.backward()
        opt.step()

def test(dl,model,lossf):
    model.eval() # eval scope
    size,losses,corrects = len(dl.dataset),0,0

    with torch.no_grad():
        for x,y in dl:
            x,y = x.to(device),y.to(device)
            pre = model(x)
            loss = lossf(pre,y)

            losses += loss.item()
            corrects += (pre.argmax(1)==y).type(torch.float).sum().item()
    print(f'accuracy {corrects/size} loss {losses/size}')
# Ten epochs: one optimisation pass over the training loader, followed by a
# printed accuracy / loss report on the test loader.
for _ in range(10):
    train(dl=trdl, model=yp, lossf=loss, opt=opt)
    test(dl=tedl, model=yp, lossf=loss)

### v2  
- model: convolutional network (`cons`)
- save & load the trained model

In [None]:
import torch
from torch import nn
from torch.utils.data import DataLoader as dl

from torchvision import datasets as dt
from torchvision.transforms import ToTensor,Lambda , Compose
import matplotlib.pyplot as plt

# FashionMNIST train / test splits, stored under ./fa (downloaded if missing)
# and converted to tensors by ToTensor().
tr = dt.FashionMNIST(root='fa', download=True, train=True, transform=ToTensor())
te = dt.FashionMNIST(root='fa', download=True, train=False, transform=ToTensor())

# Reshuffle the training samples every epoch so the optimiser does not see the
# exact same sequence of mini-batches on each pass. The evaluation loader
# keeps the dataset order.
trdl = dl(tr, batch_size=256, shuffle=True)
tedl = dl(te, batch_size=256)

class cons(nn.Module):
    """Small CNN for 1x28x28 images producing 10 class scores.

    Layout:
        front: two Conv2d -> MaxPool2d -> BatchNorm2d -> ReLU stages, flattened.
        body:  two fully connected stages (256, then 128 units).
        fc:    final Linear(128, 10).

    ``forward`` returns raw logits. No softmax is applied because
    ``nn.CrossEntropyLoss`` applies log-softmax itself; ``argmax`` over the
    logits gives the same predicted class.
    """
    def __init__(self):
        super(cons,self).__init__()
        self.front = nn.Sequential(
            nn.Conv2d(1,32,3),
            nn.MaxPool2d(2),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(32,16,2),
            nn.MaxPool2d(2),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.Flatten()
        )
        # Discover the flattened feature width with a dummy image. The probe
        # runs in eval mode and without autograd so the random input does not
        # leak into the BatchNorm running statistics.
        self.front.eval()
        with torch.no_grad():
            sple = torch.rand(1,1,28,28)
            front_out = self.front(sple).size()[-1]
        self.front.train()
        self.body = nn.Sequential(
            nn.Linear(front_out,256),
            nn.ReLU(),
            nn.BatchNorm1d(256),
            nn.Dropout(p=0.2),
            nn.Linear(256,128),
            nn.BatchNorm1d(128),
            nn.ReLU(),
        )
        # Kept as a Sequential so the parameter names stay `fc.0.weight` /
        # `fc.0.bias`.
        self.fc = nn.Sequential(
            nn.Linear(128,10),
        )
    def forward(self,x):
        """Return the (N, 10) logits for a batch ``x`` of shape (N, 1, 28, 28)."""
        x = self.front(x)
        x = self.body(x)
        x = self.fc(x)
        return x
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Model, objective and optimiser for the convolutional network.
con0 = cons()
con0 = con0.to(device)
loss = nn.CrossEntropyLoss()
opt = torch.optim.Adam(params=con0.parameters())

def train(dl,model,lossf,opt):
    """Run one optimisation pass over every batch yielded by ``dl``.

    Args:
        dl: iterable of ``(inputs, targets)`` batches.
        model: the ``nn.Module`` being optimised.
        lossf: callable ``lossf(predictions, targets)`` returning a scalar loss.
        opt: optimiser that owns ``model``'s parameters.

    Batches are moved to the module-level ``device`` before the forward pass.
    """
    # test() leaves the model in eval mode; switch back so Dropout is active
    # and BatchNorm uses batch statistics on every epoch, not only the first.
    model.train()
    for x,y in dl:
        x,y = x.to(device),y.to(device)
        pre = model(x)
        loss = lossf(pre,y)

        opt.zero_grad()
        loss.backward()
        opt.step()

def test(dl,model,lossf):
    model.eval()
    size, acc , losses = len(dl.dataset) ,0,0
    with torch.no_grad():
        for x,y in dl:
            x,y = x.to(device),y.to(device)
            pre = model(x)
            loss = lossf(pre,y)

            acc += (pre.argmax(1)==y).type(torch.float).sum().item()
            losses += loss.item()
    print(f'{acc/size} : {losses/size}')
# Twelve epochs: one optimisation pass over the training loader, followed by
# a printed accuracy / loss report on the test loader.
for _ in range(12):
    train(dl=trdl, model=con0, lossf=loss, opt=opt)
    test(dl=tedl, model=con0, lossf=loss)

# Persist only the learned tensors (the state_dict) instead of pickling the
# whole module object. A state_dict checkpoint does not depend on the class
# being importable at load time and loads under torch.load's restricted
# (weights-only) unpickler.
torch.save(con0.state_dict(),'con2l3v1.pth')

# Rebuild the architecture, then restore the saved weights onto the active
# device.
model = cons().to(device)
model.load_state_dict(torch.load('con2l3v1.pth',map_location=device))
model.eval()

# Show "true label : predicted label" for the first ten test images.
classes = dt.FashionMNIST.classes
with torch.no_grad():
    for i in range(10):
        x,y = te[i]
        # Add the batch dimension and move the image to the model's device;
        # the dataset yields CPU tensors.
        pre = model(x.unsqueeze(0).to(device))

        print(f'{classes[y]} : {classes[pre.argmax(1).item()]}')

### v3
- transfer learning (as a feature extractor)
- `train_val` helper (train, then evaluate, once per epoch)

In [None]:
def train_val(model,iter=8):
    """Train ``model`` for ``iter`` epochs, evaluating after each one.

    A fresh ``nn.CrossEntropyLoss`` and a default-configured Adam optimiser
    over ``model.parameters()`` are created on every call. Each epoch runs
    ``train`` on the module-level ``trdl`` loader and then ``test`` on the
    module-level ``tedl`` loader.
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(params=model.parameters())

    for _epoch in range(iter):
        train(trdl, model, criterion, optimizer)
        test(tedl, model, criterion)

In [None]:
class trnet(nn.Module):
    """Compact CNN for 1x28x28 images: conv feature extractor + linear head.

    ``front`` stacks three convolutions (1 -> 10 -> 5 -> 2 channels) with
    BatchNorm and flattens the result; ``fc`` maps the flattened features to
    10 raw class scores (logits), which is what ``nn.CrossEntropyLoss``
    expects.
    """
    def __init__(self):
        super(trnet,self).__init__()
        self.front = nn.Sequential(
            nn.Conv2d(1,10,3),
            nn.MaxPool2d(2),
            nn.BatchNorm2d(10),
            nn.ReLU(),

            nn.Conv2d(10,5,3),
            nn.BatchNorm2d(5),
            nn.Conv2d(5,2,2),
            nn.BatchNorm2d(2),
            nn.ReLU(),

            nn.Flatten()
        )
        # Discover the flattened feature width with a dummy image. The probe
        # runs in eval mode and without autograd so the random input does not
        # leak into the BatchNorm running statistics.
        self.front.eval()
        with torch.no_grad():
            sple = torch.rand(1,1,28,28)
            out  = self.front(sple).size()[-1]
        self.front.train()
        self.fc = nn.Linear(out,10)

    def forward(self,x):
        """Return the (N, 10) logits for a batch ``x`` of shape (N, 1, 28, 28)."""
        x = self.front(x)
        x = self.fc(x)
        return x
# Build the small CNN on the active device and train it with the default
# number of epochs.
# NOTE(review): `tr` was previously bound to the training dataset; this
# rebinds the name to the model. `trdl` was built earlier, so training still
# works, but re-running the data cell afterwards would need `tr` to be the
# dataset again.
tr = trnet()
tr = tr.to(device)
train_val(model=tr)

In [None]:
# Persist only the learned tensors (the state_dict) of the trained trnet
# instead of pickling the whole module object, so the checkpoint loads under
# torch.load's restricted (weights-only) unpickler.
torch.save(tr.state_dict(),'tr')

class trwrap(nn.Module):
    """Transfer-learning wrapper that reuses a trained ``trnet`` as a frozen
    feature extractor.

    The convolutional ``front`` of the saved network is restored from the
    ``'tr'`` checkpoint and frozen. Its original classifier is replaced by a
    new trainable 30-unit projection, and ``body`` maps those 30 features to
    10 raw class scores (logits). No softmax is applied because
    ``nn.CrossEntropyLoss`` applies log-softmax itself.
    """
    def __init__(self):
        super(trwrap,self).__init__()
        # Rebuild the architecture on the CPU and restore the saved weights;
        # the caller moves the whole wrapper to the target device afterwards.
        self.conv = trnet()
        self.conv.load_state_dict(torch.load('tr',map_location='cpu'))
        for param in self.conv.parameters():
            param.requires_grad = False

        # Probe the flattened feature width in eval mode and without autograd
        # so the random input leaves the restored BatchNorm statistics intact.
        self.conv.front.eval()
        with torch.no_grad():
            sple = torch.rand(1,1,28,28)
            out = self.conv.front(sple).size()[-1]
        # New layers are created after the freeze loop, so they stay trainable.
        self.conv.fc = nn.Sequential(
            nn.Linear(out,30),
            nn.BatchNorm1d(30),
            nn.ReLU()
        )

        # NOTE(review): `kill` is never used by forward(); it is kept only so
        # the attribute and its parameter names remain available. Candidate
        # for removal.
        self.kill = nn.Sequential(
            nn.Linear(30,30),
            nn.BatchNorm1d(30),
            nn.ReLU(),
            nn.Linear(30,10),
        )
        self.body = nn.Sequential(
            nn.Linear(30,30),
            nn.BatchNorm1d(30),
            nn.ReLU(),
            nn.Linear(30,10),
        )

    def train(self,mode=True):
        """Switch train/eval mode while keeping the frozen extractor in eval.

        The weights of ``conv.front`` are frozen, but its BatchNorm layers
        would still update their running statistics in training mode, so the
        extractor is pinned to eval regardless of ``mode``.
        """
        super(trwrap,self).train(mode)
        self.conv.front.eval()
        return self

    def forward(self,x):
        """Return the (N, 10) logits for a batch ``x`` of shape (N, 1, 28, 28)."""
        x = self.conv(x)
        x = self.body(x)
        return x
wrap = trwrap().to(device)
train_val(wrap)