In [188]:
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torch.optim as optim
from torch.hub import load_state_dict_from_url

In [189]:
class RESNET(nn.Module):
    """ResNet-style image classifier for single-channel input.

    The network is a stem (7x7 stride-2 convolution, batch norm, ReLU,
    3x3 stride-2 max-pool), four residual stages of width 64/128/256/512
    (times ``block.expansion``), global average pooling and a linear head.

    Args:
        block: residual block class. It must expose a class attribute
            ``expansion`` and be constructible as
            ``block(in_channels, out_channels, sampling, stride=...)``.
        n: sequence of four ints, the number of blocks in stages 1-4.
        arch: architecture tag. For ``'18'`` and ``'34'`` the first stage is
            built without a projection shortcut; for any other value every
            stage starts with one.
        num_classes: size of the output layer.
        init_weight: accepted for interface compatibility; this class does
            not read it.
        **kwargs: accepted and ignored.

    Raises:
        ValueError: if ``n`` does not contain exactly four entries.
    """

    def __init__(self,block,n,arch,num_classes=10,init_weight=True,**kwargs):
        super(RESNET,self).__init__()
        if len(n)!=4:
            raise ValueError(
                "n must list the block count of each of the 4 stages, got {} entries".format(len(n)))

        # Stem: self.in_channels tracks the channel count fed to the next block
        # and is updated by make_layer as stages are built.
        self.in_channels=64
        self.conv1=nn.Conv2d(in_channels=1,out_channels=self.in_channels,kernel_size=7,stride=2,padding=3,bias=False)
        self.bn1=nn.BatchNorm2d(self.in_channels)
        self.relu=nn.ReLU(inplace=True)
        self.maxpool=nn.MaxPool2d(kernel_size=3,stride=2,padding=1)

        # Stage 1 keeps both resolution and width for the '18'/'34' layouts,
        # so it needs no projection on the skip path there.
        if arch=='18' or arch =='34':
            shortcuts=[False,True,True,True]
        else:
            shortcuts=[True,True,True,True]

        self.layer1=self.make_layer(block,64,n[0],shortcuts[0],1)
        self.layer2=self.make_layer(block,128,n[1],shortcuts[1],2)
        self.layer3=self.make_layer(block,256,n[2],shortcuts[2],2)
        # Fixed: the fourth stage takes its own depth n[3] (was n[2]).
        self.layer4=self.make_layer(block,512,n[3],shortcuts[3],2)

        self.avgpool=nn.AdaptiveAvgPool2d((1,1))
        self.fc=nn.Linear(512*block.expansion,num_classes)

    def forward(self,x):
        """Return the class scores for a batch of images ``x``."""
        x=self.relu(self.bn1(self.conv1(x)))
        x=self.maxpool(x)

        x=self.layer1(x)
        x=self.layer2(x)
        x=self.layer3(x)
        x=self.layer4(x)

        x=self.avgpool(x)
        # Collapse everything but the batch dimension before the linear head.
        x=x.view(x.size(0), -1)
        x=self.fc(x)

        return x

    def make_layer(self,block,out_channels,sub_layers,sampling,strides):
        """Build one stage as an ``nn.Sequential`` of ``sub_layers`` blocks.

        When ``sampling`` is true, only the first block gets the projection
        shortcut and the stage stride ``strides``; the remaining blocks use
        stride 1. When it is false, every block is built with the block
        class's default stride. ``self.in_channels`` is advanced to
        ``out_channels*block.expansion`` after each block.
        """
        layer=[]
        for i in range(sub_layers):
            if sampling and i==0:
                layer.append(block(self.in_channels,out_channels,True,stride=strides))
            elif sampling:
                layer.append(block(self.in_channels,out_channels,False,stride=1))
            else:
                layer.append(block(self.in_channels,out_channels,False))
            self.in_channels=out_channels*block.expansion
        return nn.Sequential(*layer)

In [190]:
def resnet18(pretrain=False,progress=True,**kwargs):
    """Build the '18' configuration: ``BasicBlock`` with two blocks per stage.

    ``pretrain``, ``progress`` and any extra keyword arguments are forwarded
    unchanged to ``_resnet``.
    """
    blocks_per_stage=[2,2,2,2]
    return _resnet(BasicBlock,blocks_per_stage,'18',pretrain,progress,**kwargs)

In [191]:
def _resnet(block,n,arch,pretrain,progress,**kwargs):
    """Construct a ``RESNET`` and optionally load a downloaded checkpoint.

    Args:
        block: residual block class handed to ``RESNET``.
        n: per-stage block counts handed to ``RESNET``.
        arch: architecture tag; also the key used to look up the checkpoint URL.
        pretrain: when true, ``init_weight=False`` is forced and a state dict
            is downloaded and loaded into the model.
        progress: whether the download shows a progress bar.
        **kwargs: forwarded to ``RESNET``.

    Returns:
        The constructed model.
    """
    if pretrain:
        kwargs['init_weight']=False
    model=RESNET(block,n,arch,**kwargs)

    if pretrain:
        # NOTE(review): `model_urls` is not defined in the visible cells; it must
        # map arch tags such as '18' to checkpoint URLs before pretrain=True can work.
        # NOTE(review): the checkpoint must match this model's 1-channel stem and
        # `num_classes` head, otherwise load_state_dict raises -- confirm.
        # `progress` must be passed by keyword: the second positional parameter
        # of load_state_dict_from_url is `model_dir`, not `progress`.
        state_dict=load_state_dict_from_url(model_urls[arch],progress=progress)
        model.load_state_dict(state_dict)

    return model

In [192]:
class BasicBlock(nn.Module):
    """Two-convolution residual block.

    The main path is 3x3 conv -> batch norm -> ReLU -> 3x3 conv -> batch norm.
    The input is added back (optionally through a 1x1 conv + batch norm
    projection) and a final ReLU is applied.

    Args:
        in_channels: channels of the incoming tensor.
        out_channels: channels produced by the block.
        sampling: when true, the skip path goes through a 1x1 projection so
            its channel count and spatial size match the main path.
        stride: stride of the first convolution and of the projection.
    """
    # Output channels are out_channels * expansion; this block does not widen.
    expansion=1

    def __init__(self,in_channels,out_channels,sampling=False,stride=1):
        super(BasicBlock,self).__init__()
        self.conv1=nn.Conv2d(in_channels=in_channels,out_channels=out_channels,kernel_size=3,stride=stride,padding=1,bias=False)
        self.bn1=nn.BatchNorm2d(out_channels)
        self.relu=nn.ReLU(inplace=True)
        self.conv2=nn.Conv2d(in_channels=out_channels,out_channels=out_channels,kernel_size=3,stride=1,padding=1,bias=False)
        self.bn2=nn.BatchNorm2d(out_channels)

        self.sampling=sampling
        if self.sampling:
            # Fixed: the projection follows the block's stride (was hard-coded to 2),
            # so the skip path and the main path always have the same spatial size.
            self.downsample=nn.Sequential(
                nn.Conv2d(in_channels=in_channels,out_channels=out_channels,kernel_size=1,stride=stride,padding=0,bias=False),
                nn.BatchNorm2d(out_channels)
            )

    def forward(self,x):
        """Apply the block to ``x`` and return the activated residual sum."""
        # No copy is needed: nothing below modifies x in place.
        identity=x
        out=self.conv1(x)
        out=self.bn1(out)
        out=self.relu(out)
        out=self.conv2(out)
        out=self.bn2(out)

        if self.sampling:
            identity=self.downsample(identity)

        out+=identity
        out=self.relu(out)
        return out

In [193]:
# Run on the GPU when CUDA is available, otherwise on the CPU.
device_name="cuda" if torch.cuda.is_available() else "cpu"
device=torch.device(device_name)
device

device(type='cuda')

In [194]:
# Preprocessing applied to every image: resize to 224x224, then convert to a tensor.
# No channel repeat is applied -- the images stay single-channel, which is what
# the model's first convolution (in_channels=1) takes.
preprocessing_steps=[
    transforms.Resize((224,224)),
    transforms.ToTensor(),
]
transform_config=transforms.Compose(preprocessing_steps)

In [195]:
# Both FashionMNIST splits use the same download directory and the same preprocessing.
DATA_ROOT='/home/ubuntu/gpu_work'
train_dataset=datasets.FashionMNIST(DATA_ROOT,train=True,download=True,transform=transform_config)
test_dataset=datasets.FashionMNIST(DATA_ROOT,train=False,download=True,transform=transform_config)


In [196]:
# One batch size for both loaders; change it here only.
BATCH_SIZE=256
# Training batches are reshuffled on every pass over the data.
train_loader=torch.utils.data.DataLoader(train_dataset,batch_size=BATCH_SIZE,shuffle=True)
# Evaluation only counts correct predictions, so sample order is irrelevant and
# shuffling is left off.
test_loader=torch.utils.data.DataLoader(test_dataset,batch_size=BATCH_SIZE,shuffle=False)

In [197]:
# Model on the selected device, cross-entropy as the objective, and Adam over
# all model parameters.
model=resnet18()
model=model.to(device)
loss_fn=nn.CrossEntropyLoss()
optimizer=optim.Adam(model.parameters(),lr=1e-4)

In [198]:
def train(model,train_loader,optimizer,device,epoch,loss_fn=None):
    """Run one optimisation pass over ``train_loader``.

    Args:
        model: module to train; it is switched to training mode.
        train_loader: iterable of ``(data, label)`` batches. It must support
            ``len()`` and expose ``.dataset`` (used for progress reporting).
        optimizer: optimizer that updates ``model``'s parameters.
        device: device the batches are moved to before the forward pass.
        epoch: epoch number, used only in the progress message.
        loss_fn: callable ``loss_fn(output, label)`` returning a scalar loss.
            Defaults to ``nn.CrossEntropyLoss()`` so the function does not
            depend on a module-level variable.

    A progress line is printed after every 50th batch.
    """
    if loss_fn is None:
        loss_fn=nn.CrossEntropyLoss()

    model.train()
    num_batches=len(train_loader)
    seen=0
    for batch_ids, (data,label) in enumerate(train_loader):
        data=data.to(device)
        # Convert to int64 class indices and move to the device in one step,
        # without the CPU round-trip of label.type(torch.LongTensor).
        label=label.to(device=device,dtype=torch.long)

        optimizer.zero_grad()
        model_output=model(data)
        loss=loss_fn(model_output,label)
        loss.backward()
        optimizer.step()

        # Count what has actually been processed, including the current batch
        # and a possibly shorter final one.
        seen+=len(data)
        if (batch_ids+1)%50 == 0:
            print("Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format(
                epoch,seen,len(train_loader.dataset),
                100.*(batch_ids+1)/num_batches,loss.item()))

In [201]:
def test(model,test_loader,device):
    model.eval()
    correct=0
    with torch.no_grad():
        for data,label in test_loader:
            data,label=data.to(device),label.to(device)
            y_hat=model(data)
            _,y_pred=torch.max(y_hat,1)
            correct+=(y_pred==label).sum().item()
        print("\n Test Set: Average loss: xx , Accuracy:{}/{} ({:.0f}%)".format(
            correct,len(test_dataset),100.*correct/len(test_dataset)))
        print("="*50)

In [202]:
if __name__=='__main__':
    seed=42
    EPOCHS=10

    # Apply the seed (it was assigned but never used) so the random draws made
    # from here on, such as the training shuffle order, are repeatable.
    # NOTE(review): the model is created in an earlier cell, so its initial
    # weights are not covered by this seed.
    torch.manual_seed(seed)

    # One training pass followed by one evaluation pass per epoch.
    for epoch in range(1,EPOCHS+1):
        train(model,train_loader,optimizer,device,epoch)
        test(model,test_loader,device)


 Test Set: Average loss: xx , Accuracy:8960/10000 (90%)

 Test Set: Average loss: xx , Accuracy:9075/10000 (91%)

 Test Set: Average loss: xx , Accuracy:9180/10000 (92%)

 Test Set: Average loss: xx , Accuracy:8911/10000 (89%)

 Test Set: Average loss: xx , Accuracy:9031/10000 (90%)

 Test Set: Average loss: xx , Accuracy:9095/10000 (91%)

 Test Set: Average loss: xx , Accuracy:8799/10000 (88%)

 Test Set: Average loss: xx , Accuracy:9081/10000 (91%)

 Test Set: Average loss: xx , Accuracy:9099/10000 (91%)

 Test Set: Average loss: xx , Accuracy:9113/10000 (91%)


In [186]:
# Memory clean-up between runs: collect unreachable Python objects first, then
# ask PyTorch to release the cached CUDA memory it is no longer using.
import gc
# `del model` is left disabled so the trained model stays available; enable it
# to let the model's memory be reclaimed as well.
#del model
gc.collect()
torch.cuda.empty_cache()