In [1]:
import torch
import torchvision
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets
from torchvision.transforms import InterpolationMode
import time
import os
import numpy as np
import pickle

In [2]:
resnet = torchvision.models.segmentation.deeplabv3_resnet50(weights=torchvision.models.segmentation.DeepLabV3_ResNet50_Weights.COCO_WITH_VOC_LABELS_V1)

In [3]:
resnet.classifier[4] = torch.nn.Conv2d(256,3,kernel_size=(1,1), stride=(1,1))
resnet.aux_classifier[4]  = torch.nn.Conv2d(256,3,kernel_size=(1,1), stride=(1,1))
resnet.classifier[0].convs[4][2].eval()

BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)

In [4]:
my_gpu = torch.device(0)
resnet.to(my_gpu)

DeepLabV3(
  (backbone): IntermediateLayerGetter(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (downsample): Se

In [5]:
#img_transform = torchvision.transforms.Compose([torchvision.transforms.Resize((1024,1024)), torchvision.transforms.ToTensor()])
#img_transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor(),torchvision.transforms.Normalize(mean=meanSet.tolist(),std=stdSet.tolist())])
img_transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor(),torchvision.transforms.Resize(500),torchvision.transforms.Normalize(mean=[.5,.5,.5],std=[.5,.5,.5])])
tar_transform = torchvision.transforms.Compose([torchvision.transforms.PILToTensor(),torchvision.transforms.Resize(500,interpolation=InterpolationMode.NEAREST),torch.squeeze,torch.Tensor.long])
trainset = datasets.oxford_iiit_pet.OxfordIIITPet('G:/datasets/',transform=img_transform,target_types="segmentation",target_transform=tar_transform,binary=True)
testset = datasets.oxford_iiit_pet.OxfordIIITPet('G:/datasets/',split='test',transform=img_transform,target_types="segmentation",target_transform=tar_transform,binary=True)
trainloader = DataLoader(trainset,batch_size=1,shuffle=True)
testloader = DataLoader(testset,batch_size=1)
optim = torch.optim.SGD(resnet.parameters(),lr=1e-4,momentum=0.9)
crloss = torch.nn.CrossEntropyLoss(ignore_index=3)

In [6]:
loss_total = 0
step_total = 0
epochs = 1
for epoch in range(epochs):
    for img,target in trainloader:
        img = img.to(my_gpu)
        target = torchvision.transforms.functional.resize(target,500,interpolation=torchvision.transforms.InterpolationMode.NEAREST)
        target = torch.Tensor(target).to(my_gpu)
        optim.zero_grad()
        output = resnet(img)
        l1 = crloss(output['out'],target)
        l2 = crloss(output['aux'],target)
        loss = l1 + l2
        loss.backward()
        optim.step()
        loss.detach()
        torch.cuda.empty_cache()
        ld = loss.item()
        loss_total += ld
        step_total += 1
        if step_total % 100 == 0 and step_total != 0:
            print(loss_total/step_total)
            loss_total = 0
            step_total = 0
    #resnet.eval()
    #tStart = time.time()
    #samples = 0
    #correct = 0
    #for img,target in trainloader:
    #    img = img.to(my_gpu)
    #    target = torch.Tensor(target).to(my_gpu)
    #    output = resnet(img)
    #    samples += 1
    #    if target in output.topk(1).indices:
    #        correct += 1
    print('=======Epoch '+str(epoch)+'========')
    #print('% correct')
    #print(correct/samples)
    #print('======================')
    #resnet.train()

1.8796075761318207
1.4579942589998245
1.2815598165988922
1.0870730179548262
1.1059481471776962
0.9582273182272911
0.9885029336810112
0.9024903565645218
0.8528660282492637
1.006384706199169
0.8530811947584153
0.9323316261172294
0.8028081175684929
0.8668209785223007
0.8689862886071205
0.8015311384201049
0.8291705712676049
0.8263963171839714
0.7561054632067681
0.8046444880962372
0.7924031391739845
0.7595190452039242
0.7658590537309646
0.747298055589199
0.8269829317927361
0.7542937391996384
0.7474831688404083
0.7593463319540024
0.6813653001189232
0.7180224038660526
0.785240788012743
0.7367433550953865
0.675978296995163
0.6875821271538735
0.6585609801113605
0.6756694750487804


In [None]:
torch.save(resnet.state_dict(),'pets2seg.pth')

In [None]:
resnet.eval()
tStart = time.time()
samples = 0
correct = 0
for img,target in testloader:
    img = img.to(my_gpu)
    target = torch.Tensor(target).to(my_gpu)
    output = resnet(img)
    break
#    samples += 1
#    if target in output.topk(1).indices:
#        correct += 1
#tEnd = time.time()
#print(correct/samples)
#print(tEnd - tStart)
#print(len(testloader)/(tEnd-tStart))

In [None]:
img = img.detach().cpu()
output = output['out'].detach().cpu()

In [None]:
imgn = img.numpy()

In [None]:
imgn = imgn.squeeze()

In [None]:
imgn = np.moveaxis(imgn,0,-1)

In [None]:
imgn -= np.min(imgn)
imgn /= np.max(imgn)

In [None]:
import matplotlib.pyplot as plt

In [None]:
fig = plt.figure(figsize=(10,10))
plt.imshow(imgn)
plt.show()

In [None]:
output = output.numpy()

In [None]:
output = output.squeeze()

In [None]:
output.shape

In [None]:
pxClass = np.argmax(output,axis=0)

In [None]:
fig = plt.figure(figsize=(10,10))
plt.imshow(pxClass)
plt.show()

In [None]:
imgn[pxClass==0,0] = 0
imgn[pxClass==0,2] = 0
imgn[pxClass==1,1] = 0
imgn[pxClass==1,2] = 0

In [None]:
fig = plt.figure(figsize=(10,10))
plt.imshow(imgn)
plt.show()

In [None]:
targetn = target.detach().cpu().numpy()

In [None]:
targetn = targetn.squeeze()

In [None]:
fig = plt.figure(figsize=(10,10))
plt.imshow(targetn)
plt.show()