In [10]:
import torch
import PIL.Image as Image
from torchvision.utils import save_image
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
from torchvision.transforms import Compose, CenterCrop, Normalize, Scale, Resize
from torchvision.transforms import ToTensor, ToPILImage
from torch.utils.data import DataLoader
# from model import *
import numpy as np
import argparse
import matplotlib.pyplot as plt

from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import vgg16
import torch.optim as optim


In [11]:
#mian structure of our network
# compute region correlation 
class non_local_block(nn.Module):
    def __init__(self):
        super(non_local_block, self).__init__()

        self.mlp1 = nn.Linear(512, 4096)
        self.mlp2 = nn.Linear(4096, 512)
        self.theta = nn.Conv2d(in_channels=512, out_channels=256, kernel_size=1, stride=1, padding=0)
        self.phi = nn.Conv2d(in_channels=512, out_channels=256, kernel_size=1, stride=1, padding=0)
        self.conv1x1 = nn.Conv2d(in_channels=512, out_channels=512, kernel_size=1, stride=1, padding=0)
        self.pool = nn.AvgPool2d(kernel_size=4, stride=4)

        self.conv_y = nn.Conv2d(in_channels=512, out_channels=512, kernel_size=1, stride=1, padding=0)
        self.conv_lastlayer=nn.Conv2d(in_channels=1024,out_channels=512,kernel_size=1,stride=1,padding=0)

        self.pool_y=nn.AdaptiveAvgPool2d((2,2))

    def forward(self, featureX,featureY):
        #加pool
        featureY=self.pool_y(featureY)

        batch_size = featureX.size(0)  # N
        channel_size = featureX.size(1)

        theta_x = self.theta(featureX).view(batch_size, channel_size // 2, -1)  # (batch,C//2,H*W)
        theta_x = theta_x.permute(0, 2, 1)  # (batch,H2*W2,C//2)

        phi_y1 = self.phi(featureY).view(batch_size, channel_size // 2, -1)  # (batch,C//2,7*7)
        f1 = torch.matmul(theta_x, phi_y1)  # (batch,H*W,7*7)
        f_div_C1 = F.softmax(f1, dim=-1)  # normalize the last dim by softmax
        featureY = featureY.view(batch_size, channel_size, -1)  # N,512,7*7
        featureY = featureY.permute(0, 2, 1)  # N,7*7,512
        y1 = torch.matmul(f_div_C1, featureY)  # batch,H*W,C
        y1 = y1.permute(0, 2, 1).contiguous()
        # y1 = y1.view(batch_size, channel_size, 16, 16).permute(0, 2, 3, 1)  # batch,16,16,512
        # # experiment12 对比下linear和1x1卷积
        # W_y1 = self.mlp2(F.tanh(self.mlp1(y1))).permute(0, 3, 1, 2)  # batch,512,16,16
        y1=y1.view(batch_size,channel_size,16,16)   # batch,512,16,16

        return y1

        # out=torch.cat((W_y1,featureX),dim=1)
        # out=F.relu(self.conv_lastlayer(out))
        # # out=W_y1*featureX
        #
        # return out


# Define some constants
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv2d') != -1:
        #print(classname)
        m.weight.data.normal_(0.0, 0.02)

    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)

# LSTM to optimize
class ConvLSTMCell(nn.Module):
    """
    Generate a convolutional LSTM cell
    """

    def __init__(self, input_size, hidden_size,kernel_size=3,dilation=1):
        super().__init__()
        self.input_size = input_size#input channel
        self.hidden_size = hidden_size#hidden channel
        self.Gates = nn.Conv2d(in_channels=input_size + hidden_size,
                               out_channels= 4 * hidden_size,
                               kernel_size=kernel_size,
                               dilation=dilation,
                               padding=(kernel_size-3)//2+dilation)
        self.Gates.apply(weights_init)


    def forward(self, input_, prev_state):

        # get batch and spatial sizes
        batch_size = input_.data.size()[0]
        spatial_size = input_.data.size()[2:]

        # generate empty prev_state, if None is provided
        if prev_state is None:
            state_size = [batch_size, self.hidden_size] + list(spatial_size)
            prev_state = (
                torch.zeros(state_size),
                torch.zeros(state_size)
            )

        prev_hidden, prev_cell = prev_state#previous state

        # data size is [batch, channel, height, width]
        stacked_inputs = torch.cat((input_, prev_hidden), 1)
        gates = self.Gates(stacked_inputs)
        #print (gates.shape)

        # chunk across channel dimension
        in_gate, remember_gate, out_gate, cell_gate = gates.chunk(4, 1)

        # apply sigmoid non linearity
        in_gate = torch.sigmoid(in_gate)
        remember_gate = torch.sigmoid(remember_gate)
        out_gate = torch.sigmoid(out_gate)

        # apply tanh non linearity
        cell_gate = torch.tanh(cell_gate)

        # compute current cell and hidden state
        cell = (remember_gate * prev_cell) + (in_gate * cell_gate)
        hidden = out_gate * torch.tanh(cell)

        return hidden, cell

# our model
class model(nn.Module):
    def __init__(self):
        super(model,self).__init__()
        self.pretrained_model=vgg16(pretrained=True)
        self.features,self.classifiers=list(self.pretrained_model.features.children()),list(self.pretrained_model.classifier.children())

        self.features_map=nn.Sequential(*self.features)
        self.global_avg_pool=nn.AdaptiveAvgPool2d((1,1))
        self.mlp1=nn.Linear(512,4096)
        self.mlp2=nn.Linear(4096,512)
        self.upsample=nn.Upsample(16)
        self.dec = Decoder(2, 512, 2, activ='relu', pad_type='reflect')

        self.lstm_cell = ConvLSTMCell(512, 512, kernel_size=3, dilation=1)
        self.conv=nn.Conv2d(in_channels=512,out_channels=512,kernel_size=1,stride=1)

        self.non_local_block1=non_local_block()
        self.conv_last=nn.Conv2d(in_channels=1024,out_channels=512,kernel_size=1)
        self.iteration=4

    def forward(self, x,y):
        vgg_x,vgg_y,vgg_x_weight,vgg_y_weight = self.encode(x,y)
        images_recon_x,images_recon_y = self.decode(vgg_x_weight,vgg_y_weight)
        return images_recon_x,images_recon_y


    def encode(self, x,y):
        vgg_x = self.features_map(x)
        vgg_y = self.features_map(y)

        x_input=self.upsample(self.global_avg_pool(vgg_y))
        y_input=self.upsample(self.global_avg_pool(vgg_x))

        hidden_state_x = vgg_x
        cell_x = vgg_x

        hidden_state_y = vgg_y
        cell_y = vgg_y

        for i in range(self.iteration):
            hidden_state_x, cell_x = self.lstm_cell(x_input, (hidden_state_x, cell_x))
            hidden_state_y, cell_y = self.lstm_cell(y_input, (hidden_state_y, cell_y))

            non_local_x=self.non_local_block1(cell_x,cell_y)
            non_local_y=self.non_local_block1(cell_y,cell_x)

            x_input=self.upsample(self.global_avg_pool(cell_y))
            x_input=(x_input+non_local_x)/2
            y_input=self.upsample(self.global_avg_pool(cell_x))
            y_input=(y_input+non_local_y)/2

        # vgg_x_weight = self.global_avg_pool(vgg_x)
        # vgg_x_weight = self.upsample(F.softmax(self.mlp2(F.tanh(self.mlp1(vgg_x_weight.view(-1,512)))),dim=-1).view(-1,512,1,1))
        #
        # vgg_y_weight = self.global_avg_pool(vgg_y)
        # vgg_y_weight = self.upsample(F.softmax(self.mlp2(F.tanh(self.mlp1(vgg_y_weight.view(-1,512)))),dim=-1).view(-1,512,1,1))
        return vgg_x,vgg_y,hidden_state_x,hidden_state_y

    def decode(self,  vgg_x_weight, vgg_y_weight):
        vgg_x_weight=F.relu(self.conv(vgg_x_weight))
        vgg_y_weight=F.relu(self.conv(vgg_y_weight))

        images_x = self.dec(vgg_x_weight)
        images_y = self.dec(vgg_y_weight)
        return images_x,images_y

    def set_iteration(self,n):
        self.iteration=n



class LayerNorm(nn.Module):
    def __init__(self, num_features, eps=1e-5, affine=True):
        super(LayerNorm, self).__init__()
        self.num_features = num_features
        self.affine = affine
        self.eps = eps

        if self.affine:
            self.gamma = nn.Parameter(torch.Tensor(num_features).uniform_())
            self.beta = nn.Parameter(torch.zeros(num_features))

    def forward(self, x):
        shape = [-1] + [1] * (x.dim() - 1)
        mean = x.view(x.size(0), -1).mean(1).view(*shape)
        std = x.view(x.size(0), -1).std(1).view(*shape)
        x = (x - mean) / (std + self.eps)

        if self.affine:
            shape = [1, -1] + [1] * (x.dim() - 2)
            x = x * self.gamma.view(*shape) + self.beta.view(*shape)
        return x

class ResBlocks(nn.Module):
    def __init__(self, num_blocks, dim, norm='bn', activation='relu', pad_type='zero'):
        super(ResBlocks, self).__init__()
        self.model = []
        for i in range(num_blocks):
            self.model += [ResBlock(dim, norm=norm, activation=activation, pad_type=pad_type)]
        self.model = nn.Sequential(*self.model)

    def forward(self, x):
        return self.model(x)

# predict masks

class Decoder(nn.Module):
    def __init__(self, n_res, dim, output_dim, activ='relu', pad_type='zero'):
        super(Decoder, self).__init__()

        self.model = []
        # AdaIN residual blocks
        self.model += [ResBlocks(n_res, dim, 'bn', activ, pad_type=pad_type)]

        self.model += [nn.Upsample(scale_factor=2),
                           Conv2dBlock(dim, dim // 2, 5, 1, 2, norm='bn', activation=activ, pad_type='reflect')]
        dim //= 2
        self.model += [nn.Upsample(scale_factor=2),
                           Conv2dBlock(dim, dim // 2, 5, 1, 2, norm='bn', activation=activ, pad_type='reflect')]
        dim //= 2
        self.model += [nn.Upsample(scale_factor=2),
                           Conv2dBlock(dim, dim // 2, 5, 1, 2, norm='bn', activation=activ, pad_type='reflect'),
                           ]
        dim //= 2
        self.model += [nn.Upsample(scale_factor=2),
                           Conv2dBlock(dim, dim // 2, 5, 1, 2, norm='bn', activation=activ, pad_type='reflect'),
                           ]
        dim //= 2
        self.model += [nn.Upsample(scale_factor=2),
                           Conv2dBlock(dim, dim // 2, 5, 1, 2, norm='bn', activation=activ, pad_type='reflect'),
			   ]
        dim //= 2
        # use reflection padding in the last conv layer
        self.model += [Conv2dBlock(dim, output_dim, 7, 1, 3, norm='bn', activation='none', pad_type='reflect')]
        self.model = nn.Sequential(*self.model)

    def forward(self, x):
        return self.model(x)



##################################################################################
# Basic Blocks
##################################################################################
class ResBlock(nn.Module):
    def __init__(self, dim, norm='ln', activation='relu', pad_type='zero'):
        super(ResBlock, self).__init__()

        model = []
        model += [Conv2dBlock(dim ,dim, 3, 1, 1, norm=norm, activation=activation, pad_type=pad_type)]
        model += [Conv2dBlock(dim ,dim, 3, 1, 1, norm=norm, activation='none', pad_type=pad_type)]
        self.model = nn.Sequential(*model)

    def forward(self, x):
        residual = x
        out = self.model(x)
        out += residual
        return out

class Conv2dBlock(nn.Module):
    def __init__(self, input_dim ,output_dim, kernel_size, stride,
                 padding=0, norm='none', activation='relu', pad_type='zero'):
        super(Conv2dBlock, self).__init__()
        self.use_bias = True
        # initialize padding
        if pad_type == 'reflect':
            self.pad = nn.ReflectionPad2d(padding)
        elif pad_type == 'zero':
            self.pad = nn.ZeroPad2d(padding)

        else:
            assert 0, "Unsupported padding type: {}".format(pad_type)

        # initialize normalization
        norm_dim = output_dim
        if norm == 'bn':
            self.norm = nn.BatchNorm2d(norm_dim)
        elif norm == 'in':
            self.norm = nn.InstanceNorm2d(norm_dim)
        elif norm == 'ln':
            self.norm = LayerNorm(norm_dim)
        elif norm == 'adain':
            self.norm = AdaptiveInstanceNorm2d(norm_dim)
        elif norm == 'none':
            self.norm = None
        else:
            assert 0, "Unsupported normalization: {}".format(norm)

        # initialize activation
        if activation == 'relu':
            self.activation = nn.ReLU(inplace=True)
        elif activation == 'lrelu':
            self.activation = nn.LeakyReLU(0.2, inplace=True)
        elif activation == 'prelu':
            self.activation = nn.PReLU()
        elif activation == 'selu':
            self.activation = nn.SELU(inplace=True)
        elif activation == 'tanh':
            self.activation = nn.Tanh()
        elif activation == 'sigmoid':
            self.activation = nn.Sigmoid()
        elif activation == 'softmax':
            self.activation = nn.Softmax(dim=-1)
        elif activation == 'none':
            self.activation = None
        else:
            assert 0, "Unsupported activation: {}".format(activation)

        # initialize convolution
        self.conv = nn.Conv2d(input_dim, output_dim, kernel_size, stride, bias=self.use_bias)

    def forward(self, x):
        x = self.conv(self.pad(x))
        if self.norm:
            x = self.norm(x)
        if self.activation:
            x = self.activation(x)
        return x




In [12]:
#data load and preprocess
import numpy as np
import os
from PIL import Image
from torch.utils.data import Dataset
# import skimage.io as io
import glob
import numpy as np
import random


def get_images(filename):
    image_names = np.genfromtxt(filename, dtype=str)
    return image_names


def load_image(file):
    return Image.open(file)


class coseg_train_dataset(Dataset):
    def __init__(self, data_dir, label_dir, traintxt, input_transform=None, label_transform=None):
        self.data_dir = data_dir
        self.label_dir = label_dir
        self.input_transform = input_transform
        self.label_transform = label_transform
        self.traintxt = traintxt
        self.train_names = get_images(self.traintxt)

    def __getitem__(self, index):

        imagename1 = self.data_dir + self.train_names[index][0] + ".jpg"
        imagename2 = self.data_dir + self.train_names[index][1] + ".jpg"
        labelname1 = self.label_dir + self.train_names[index][2] + ".png"
        labelname2 = self.label_dir + self.train_names[index][3] + ".png"

        with open(imagename1, "rb") as f:
            image1 = load_image(f).convert('RGB')
        with open(imagename2, "rb") as f:
            image2 = load_image(f).convert('RGB')

        with open(labelname1, "rb") as f:
            label1 = load_image(f).convert('L')
        with open(labelname2, "rb") as f:
            label2 = load_image(f).convert('L')

        # random horizontal flip
        if random.random() < 0.5:
            image1 = image1.transpose(Image.FLIP_LEFT_RIGHT)
            label1 = label1.transpose(Image.FLIP_LEFT_RIGHT)
            # image1 = image1.transpose(Image.ROTATE_90)
            # label1=label1.transpose(Image.ROTATE_90)

        # random horizontal flip
        if random.random() < 0.5:
            image2 = image2.transpose(Image.FLIP_LEFT_RIGHT)
            label2 = label2.transpose(Image.FLIP_LEFT_RIGHT)
            # image2 = image2.transpose(Image.ROTATE_90)
            # label2=label2.transpose(Image.ROTATE_90)

        if self.input_transform is not None:
            image1 = self.input_transform(image1)
            image2 = self.input_transform(image2)

        if self.label_transform is not None:
            label1 = self.label_transform(label1)
            label2 = self.label_transform(label2)

        return image1, image2, label1, label2

    def __len__(self):
        return len(self.train_names)


class coseg_val_dataset(Dataset):
    def __init__(self, data_dir, label_dir, val_txt, input_transform=None, label_transform=None):
        self.data_dir = data_dir
        self.label_dir = label_dir
        self.input_transform = input_transform
        self.label_transform = label_transform
        self.val_txt = val_txt
        self.val_names = get_images(self.val_txt)

    def __getitem__(self, index):

        imagename1 = self.data_dir + self.val_names[index][0] + ".jpg"
        imagename2 = self.data_dir + self.val_names[index][1] + ".jpg"
        labelname1 = self.label_dir + self.val_names[index][2] + ".png"
        labelname2 = self.label_dir + self.val_names[index][3] + ".png"

        with open(imagename1, "rb") as f:
            image1 = load_image(f).convert('RGB')
        with open(imagename2, "rb") as f:
            image2 = load_image(f).convert('RGB')

        with open(labelname1, "rb") as f:
            label1 = load_image(f).convert('L')
        with open(labelname2, "rb") as f:
            label2 = load_image(f).convert('L')

        if self.input_transform is not None:
            image1 = self.input_transform(image1)
            image2 = self.input_transform(image2)

        if self.label_transform is not None:
            label1 = self.label_transform(label1)
            label2 = self.label_transform(label2)

        return image1, image2, label1, label2

    def __len__(self):
        return len(self.val_names)





In [18]:
#train setting and process
#you may need to change the files' path below
train_data_dir="/home/guankai/PascalVocCoseg/image/"
train_label_dir="/home/guankai/PascalVocCoseg/colabel/train/"
train_txt="/home/guankai/PascalVocCoseg/colabel/train.txt"
val_data_dir="/home/guankai/PascalVocCoseg/image/"
val_label_dir="/home/guankai/PascalVocCoseg/colabel/val/"
val_txt="/home/guankai/PascalVocCoseg/colabel/val1600.txt"

# print(os.path.abspath('/home/guankai/PascalVocCoseg/colabel/train.txt'))

parser = argparse.ArgumentParser()
parser.add_argument("--verbosity", help="increase output verbosity")


parser = argparse.ArgumentParser(description='Attention Based Co-segmentation')
parser.add_argument('--lr', default=1e-5, type=float, help='learning rate')
parser.add_argument('--weight_decay', default=0.0005,
                    help='weight decay value')
parser.add_argument('--gpu_ids', default=[0], help='a list of gpus')
parser.add_argument('--num_worker', default= 8, help='numbers of worker')
parser.add_argument('--batch_size', default=4, help='bacth size')
parser.add_argument('--epoches', default=2, help='epoches')

parser.add_argument('--train_data', help='training data directory',default=train_data_dir)
parser.add_argument('--val_data', help='validation data directory',default=val_data_dir)
parser.add_argument('--train_txt', help='training image pair names txt',default=train_txt)
parser.add_argument('--val_txt', help='validation image pair names txt',default=val_txt)
parser.add_argument('--train_label', help='training label directory',default=train_label_dir)
parser.add_argument('--val_label', help='validation label directory',default=val_label_dir)
parser.add_argument('--model_path', help='model saving directory',default='model_path/')

args = parser.parse_args(args=[])

# let the label pixels =1 if it >0
class Relabel:
    def __call__(self, tensor):
        assert isinstance(
            tensor, torch.LongTensor), 'tensor needs to be LongTensor'
        tensor[tensor > 0] = 1
        return tensor

# numpy -> tensor
class ToLabel:
    def __call__(self, image):
        return torch.from_numpy(np.array(image)).long()



class Trainer:
    def __init__(self):
        self.args = args
        self.input_transform = Compose([Resize((512, 512)), ToTensor(), Normalize([.485, .456, .406], [.229, .224, .225])])
        self.label_transform = Compose([Resize((512, 512)), CenterCrop(512), ToLabel(), Relabel()])

        self.net = model().cuda()
        #self.net.turn_off(self.net)
        self.net = nn.DataParallel(self.net, device_ids=self.args.gpu_ids)
        # self.net.cuda()
        self.train_data_loader = DataLoader(coseg_train_dataset(self.args.train_data, self.args.train_label, self.args.train_txt, self.input_transform, self.label_transform),
                                            num_workers=self.args.num_worker, batch_size=self.args.batch_size, shuffle=True)
        self.val_data_loader = DataLoader(coseg_val_dataset(self.args.val_data, self.args.val_label, self.args.val_txt, self.input_transform, self.label_transform),
                                          num_workers=self.args.num_worker, batch_size=self.args.batch_size, shuffle=False)
        self.params=filter(lambda p:p.requires_grad,self.net.parameters())
        #self.optimizer = optim.Adam(filter(lambda p:p.requires_grad,self.net.parameters()), lr=self.args.lr, weight_decay=self.args.weight_decay)
        #weight_decay   权重衰减 --正则化方法之一
        self.optimizer=optim.Adam(self.net.parameters(),lr=self.args.lr,weight_decay=self.args.weight_decay)
        self.loss_func = nn.CrossEntropyLoss()

    def pixel_accuracy(self, output, label):
        correct = len(output[output == label])
        wrong = len(output[output != label])
        return correct, wrong

    def jaccard(self, output, label):
        # print('output',output)
        # print('label',label)
        # ss
        temp = output[label == 1]
        i = len(temp[temp == 1])
        temp = output + label
        u = len(temp[temp > 0])
        return i, u

    def precision(self, output, label):
        temp = output[label == 1]
        tp = len(temp[temp == 1])
        p = len(output[output > 0])
        return tp, p

    def evaluate(self, net, epoch):
        print('--eval')
        self.net.eval()
        correct = 0
        wrong = 0
        intersection = 0
        union = 0
        true_positive = 0
        positive = 1
        for i, (image1, image2, label1, label2) in enumerate(self.val_data_loader):
            image1, image2, label1, label2 = image1.cuda(
            ), image2.cuda(), label1.cuda(), label2.cuda()
            with torch.no_grad():

                output1, output2 = self.net(image1, image2)
            output1 = torch.argmax(output1, dim=1)
            output2 = torch.argmax(output2, dim=1)

            # eval output1
            c, w = self.pixel_accuracy(output1, label1)
            correct += c
            wrong += w

            i, u = self.jaccard(output1, label1)
            intersection += i
            union += u

            tp, p = self.precision(output1, label1)
            true_positive += tp
            positive += p
            # eval output2
            c, w = self.pixel_accuracy(output2, label2)
            correct += c
            wrong += w

            i, u = self.jaccard(output2, label2)
            intersection += i
            union += u

            tp, p = self.precision(output2, label2)
            true_positive += tp
            positive += p

        print("pixel accuracy: {} correct: {}  wrong: {}".format(
            correct / (correct + wrong), correct, wrong))
        print("precision: {} true_positive: {} positive: {}".format(
            true_positive / positive, true_positive, positive))
        print("jaccard score: {} intersection: {} union: {}".format(
            intersection / union, intersection, union))
        self.net.train()
        return correct / (correct + wrong), intersection / union, true_positive / positive

    def train(self):
        precison_list=[]
        jaccard_list=[]
        for epoch in range(self.args.epoches):
            losses = []
            for i, (image1, image2, label1, label2) in enumerate(self.train_data_loader):
                image1, image2, label1, label2 = image1.cuda(
                ), image2.cuda(), label1.cuda(), label2.cuda()
                #print('image1',image1.shape)#[N, 3, 512, 512]


                output1, output2 = self.net(image1, image2)


                #print('outpu1',output1.shape)#[N, 2, 512, 512]
                # calculate loss from output1 and output2
                loss = self.loss_func(output1, label1)
                loss += self.loss_func(output2, label2)

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
                losses.append(loss.data.cpu().numpy())

                if i % 2000 == 0:
                    print("---------------------------------------------")
                    print("epoch{} iter {}/{} BCE loss:".format(epoch,
                                                                i, len(self.train_data_loader), np.mean(losses)))
                    print("testing......")
                    acc, jac, pre = self.evaluate(self.net, epoch)
                    precison_list.append(pre)
                    jaccard_list.append(jac)
                    plot_precisonAndjac('', precison_list,jaccard_list)


                if i % 2000 == 0 and i != 0:
                    torch.save(self.net.state_dict(),
                               'epoch{}iter{}.pkl'.format(epoch, i))


def plot_precisonAndjac(checkpoint_dir,pre_list,jac_list):
    x=range(0,len(pre_list))
    y=pre_list
    y2=jac_list
    plt.switch_backend('agg')
    plt.plot(x,y,color='red',marker='o',label='precision')
    plt.plot(x,y2,color='blue',marker='o',label='jaccard')
    plt.xticks(range(0,len(pre_list)+3,(len(pre_list)+10)//10))
    plt.legend()
    plt.grid()
    plt.savefig(os.path.join(checkpoint_dir,'precisionAndjac_fig.pdf'))
    plt.close()


trainer = Trainer()
trainer.train()

RuntimeError: Assertion `THCTensor_(checkGPU)(state, 4, input, target, output, total_weight)' failed. Some of weight/gradient/input tensors are located on different GPUs. Please move them to a single one. at /pytorch/aten/src/THCUNN/generic/SpatialClassNLLCriterion.cu:65

In [17]:

# test phase 
ckpt = ''
test_image_1 = ''
test_image_2 = ''

class Demo:
    def __init__(self):
        # self.args = args
        self.net = model().cuda()
        self.net = nn.DataParallel(self.net, device_ids=[0])
        # self.net.load_state_dict(torch.load(ckpt))

        self.input_transform = Compose([Resize((512, 512)), ToTensor(
        ), Normalize([.485, .456, .406], [.229, .224, .225])])
        # self.image1_path = self.args.image1
        # self.image2_path = self.args.image2


    def pixel_accuracy(self,output,label):
        correct=len(output[output==label])
        wrong=len(output[output!=label])
        return correct,wrong

    def jaccard(self,output,label):

        temp=output[label==1]
        intersection=len(temp[temp==1])
        temp=output+label
        union=len(temp[temp>0])
        return intersection,union

    def precision(self,output,label):
        temp=output[label==1]
        tp=len(temp[temp==1])
        p=len(output[output>0])
        return tp,p

    def single_demo(self):
        self.net.eval()
        image1 = Image.open(open(test_image_1)).convert('RGB')#Image.fromarray(np.zeros([512,512,3],dtype=np.uint8))#
        image2 = Image.open(open(test_image_2)).convert('RGB')#Image.fromarray(np.zeros([512,512,3],dtype=np.uint8))#
        image1 = self.input_transform(image1)
        image2 = self.input_transform(image2)
        image1, image2 = image1.unsqueeze(0).cuda(), image2.unsqueeze(0).cuda()

        output1, output2 = self.net(image1, image2)
        print('output_origin',output1.shape)

        output1=torch.softmax(output1,dim=1,)[:,1,:,:]
        output2=torch.softmax(output2,dim=1)[:,1,:,:]

        print('output',output1.shape)#[1,512,512]

        image1 = (image1 - image1.min()) / image1.max()
        image2 = (image2 - image2.min()) / image2.max()
        output1=torch.cat([output1,output1,output1],dim=0).unsqueeze(0)
        output2=torch.cat([output2,output2,output2],dim=0).unsqueeze(0)
        output1 = output1.float().data * 0.8 + image2.data
        output2 = output2.float().data * 0.8 + image2.data

        print(output1.shape)
        print(output2.shape)
        npimg = output1[0].cpu().numpy()
        # plt.imshow(np.transpose(npimg, (1, 2, 0)))
        npimg = np.hstack([npimg,output2[0].cpu().numpy()])
        print(npimg.shape)
        plt.imshow(np.transpose(npimg, (1, 2, 0)))
demo = Demo()
demo.single_demo()

print("Finish!!!")


FileNotFoundError: ignored