In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms

import os,sys

from torch.utils.data import Dataset, DataLoader
from torch.utils.data.dataset import random_split
from PIL import Image

In [2]:
class UNet(nn.Module):
    def __init__(self,in_channels=1,out_channels=1,init_features=64):
        super(UNet, self).__init__()

        features = init_features
        self.encoder1 = UNet.DoubleConv2d(in_channels,features)
        self.pool1 = nn.MaxPool2d(2, 2)
        
        self.encoder2 = UNet.DoubleConv2d(features,2*features)
        self.pool2 = nn.MaxPool2d(2, 2)

        self.encoder3 = UNet.DoubleConv2d(2*features,4*features)
        self.pool3 = nn.MaxPool2d(2, 2)
        
        self.encoder4 = UNet.DoubleConv2d(4*features,8*features)
        self.pool4 = nn.MaxPool2d(2, 2)
        
        self.bottleneck = UNet.DoubleConv2d(8*features,16*features)
        
        self.upconv4 = nn.ConvTranspose2d(16*features,8*features,kernel_size=2,stride=2)        
        self.decoder4 = UNet.DoubleConv2d(16*features,8*features) #concate, 2*8=16
        
        self.upconv3 = nn.ConvTranspose2d(8*features,4*features,kernel_size=2,stride=2)
        self.decoder3 = UNet.DoubleConv2d(8*features,4*features) #concate, 2*4=8
        
        self.upconv2 = nn.ConvTranspose2d(4*features,2*features,kernel_size=2,stride=2)
        self.decoder2 = UNet.DoubleConv2d(4*features,2*features) #concate, 2*2=4
        
        
        self.upconv1 = nn.ConvTranspose2d(2*features,features,kernel_size=2,stride=2)
        self.decoder1 = UNet.DoubleConv2d(2*features,features) #concate, 2*1=2
        
        self.conv_out = nn.Conv2d(features, 1, 1)
        
    def forward(self, input):
        enc1 = self.encoder1(input)
        
        enc2 = self.encoder2(self.pool1(enc1))
        
        enc3 = self.encoder3(self.pool2(enc2))
        
        enc4 = self.encoder4(self.pool3(enc3))
        
        bottleneck = self.bottleneck(self.pool4(enc4))

        dec4 = torch.cat([enc4,self.upconv4(bottleneck)],dim=1)
        dec4 = self.decoder4(dec4)
        
        dec3 = torch.cat([enc3,self.upconv3(dec4)],dim=1)
        dec3 = self.decoder3(dec3)
        
        dec2 = torch.cat([enc2,self.upconv2(dec3)],dim=1)
        dec2 = self.decoder2(dec2)
        
        dec1 = torch.cat([enc1,self.upconv1(dec2)],dim=1)
        dec1 = self.decoder1(dec1)
        
        output = torch.sigmoid(self.conv_out(dec1))
        
        return output
    
    def DoubleConv2d(in_channels,features):
        return nn.Sequential(
                nn.Conv2d(in_channels, features, kernel_size=3, padding=1),
                nn.BatchNorm2d(features),
                nn.ReLU(inplace=True),
                nn.Conv2d(features, features, kernel_size=3, padding=1),
                nn.BatchNorm2d(features),
                nn.ReLU(inplace=True)
        )

In [3]:
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        
        self.img_conv1 = nn.Conv2d(1, 16, kernel_size = 3, padding=1)
        self.img_conv2 = nn.Conv2d(16, 64, kernel_size = 3, padding=1)
        self.img_pool = nn.MaxPool2d(2, 2)
        
        self.mask_conv = nn.Conv2d(1, 64, kernel_size = 3, padding=1)
        self.mask_pool = nn.MaxPool2d(2, 2)
        
        self.dis_conv1 = nn.Conv2d(128, 256, kernel_size = 3, padding=1)
        self.dis_pool1 = nn.MaxPool2d(2, 2)
        self.dis_conv2 = nn.Conv2d(256, 512, kernel_size = 3, padding=1)
        self.dis_pool2 = nn.MaxPool2d(2, 2)
        
        self.output = nn.Linear(512,1)
        
        
    def forward(self, input, mask):
        x = self.img_conv1(input)
        x = self.img_conv2(x)
        x = self.img_pool(x)
        
        y = self.mask_conv(mask)
        y = self.mask_pool(y)
        
        x = torch.cat([x,y],dim=1)
        x = self.dis_conv1(x)
        x = self.dis_pool1(x)
        x = self.dis_conv2(x)
        x = self.dis_pool2(x)
        x = F.max_pool2d(x,kernel_size=x.size()[2:]).squeeze()

        output = torch.sigmoid(self.output(x))
        
        return output

In [4]:
class UltraSoundDataSet(Dataset):
    def __init__(self, root_dir, transforms):
        self.root_dir = root_dir
        self.sample_list = os.listdir(root_dir)

        self.transform_image, self.transform_label = transforms
        
    def __len__(self):
        return len(self.sample_list)
    
    def __getitem__(self,idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        image_path = os.path.join(self.root_dir,self.sample_list[idx],"image.png")
        label_path = os.path.join(self.root_dir,self.sample_list[idx],"label.png")
        
        image = Image.open(image_path)
        label = Image.open(label_path)
        
        if self.transform_image is not None:
            image = self.transform_image(image)
            
        if self.transform_label is not None:
            label = self.transform_label(label)
        
        return image,label
        #sample = {"image":image,"label":label}
        
        #return sample
        

In [5]:
def DiceLoss(pred,target,slack=10):
    
    index = (2*torch.sum(pred*target)+slack)/(torch.sum(pred)+torch.sum(target)+slack)
    #if torch.sum(target).item() == 0:
    #print("instersection: ",torch.sum(pred*target).item())
    #print("pred: ",torch.sum(pred).item())
    #print("target: ",torch.sum(target).item())
    #print("Index: ", index.item())
    return 1-index

In [5]:
def train(net,device,val_per=0.1,epochs=10,batch_size=10,resize_to=None):
    if resize_to is not None:
        transform_image = transforms.Compose([
        transforms.Resize(resize_to),
        transforms.ToTensor(),
        transforms.Normalize(0.5,0.5)
        ])
        transform_label = transforms.Compose([
        transforms.Resize(resize_to),
        transforms.ToTensor()
        ])
    else:
        transform_image = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(0.5,0.5)
        ])
        transform_label = transforms.Compose([
        transforms.ToTensor()
        ])
            
    
    dataSet = UltraSoundDataSet(root_dir,(transform_image,transform_label))
    nTrain = int(len(dataSet)*(1-val_per))
    nValid = int(len(dataSet)-nTrain)
    
    trainSet,validSet = random_split(dataSet,[nTrain,nValid])
    
    train_loader = DataLoader(trainSet,batch_size=batch_size,shuffle=True,num_workers=4)
    valid_loader = DataLoader(validSet,batch_size=batch_size,shuffle=True,num_workers=4)
    
    optimizer = torch.optim.Adam(net.parameters())
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,mode='min',patience=10) #mae: dice-index
    
    running_loss = 0
    step = 0
    for epoch in range(epochs):
        net.train()
        for batch in train_loader:
            imgs,labels = batch
            
            imgs = imgs.to(device=device,dtype=torch.float32)
            labels = labels.to(device=device,dtype=torch.float32)
            
            pred = net(imgs)
            
            loss = DiceLoss(pred,labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            running_loss += loss.item()
            step += 1
            
            if step % 10 == 9:    # print every 10 mini-batches
                print('[%d, %5d] loss: %.3f' %(epoch + 1, step + 1, running_loss / 10))
                running_loss = 0.0
                
            if step%50 == 49:
                val_loss = 0
                for batch in valid_loader:
                    imgs,labels = batch
                    imgs = imgs.to(device=device,dtype=torch.float32)
                    labels = labels.to(device=device,dtype=torch.float32)
                    
                    with torch.no_grad():
                        pred = net(imgs)
                    val_loss += DiceLoss(pred,labels)
                print('[%d, %5d] validation loss: %.3f' %(epoch + 1, step + 1, val_loss / len(valid_loader)))
                scheduler.step(val_loss)
            
        
    

In [6]:
def train_adversarial(net,discriminator,device,val_per=0.1,epochs=10,batch_size=10,resize_to=None):
    if resize_to is not None:
        transform_image = transforms.Compose([
        transforms.Resize(resize_to),
        transforms.ToTensor(),
        transforms.Normalize(0.5,0.5)
        ])
        transform_label = transforms.Compose([
        transforms.Resize(resize_to),
        transforms.ToTensor()
        ])
    else:
        transform_image = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(0.5,0.5)
        ])
        transform_label = transforms.Compose([
        transforms.ToTensor()
        ])
            
    
    dataSet = UltraSoundDataSet(root_dir,(transform_image,transform_label))
    nTrain = int(len(dataSet)*(1-val_per))
    nValid = int(len(dataSet)-nTrain)
    
    trainSet,validSet = random_split(dataSet,[nTrain,nValid])
    
    train_loader = DataLoader(trainSet,batch_size=batch_size,shuffle=True,num_workers=4)
    valid_loader = DataLoader(validSet,batch_size=batch_size,shuffle=True,num_workers=4)
    
    optimizer = torch.optim.Adam(net.parameters())
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,mode='min',patience=10) #mae: dice-index
    
    BCELoss = nn.BCELoss()
    
    running_loss = 0
    step = 0
    for epoch in range(epochs):
        net.train()
        discriminator.train()
        
        for batch in train_loader:
            imgs,labels = batch
            
            imgs = imgs.to(device=device,dtype=torch.float32)
            labels = labels.to(device=device,dtype=torch.float32)
            
            pred = net(imgs)
            #print("grad: ",pred.grad)
            valid = torch.Tensor(imgs.size(0), 1).fill_(1.0)
            fake = torch.Tensor(imgs.size(0), 1).fill_(0.0)
            if(torch.rand(1)>0.5):
                #print("img shape: ",imgs.shape)
                #print("label shape: ",labels.shape)
                #label_disc = torch.Tensor([1]).to(device=device,dtype=torch.float32)
                label_disc = valid.to(device=device,dtype=torch.float32)
                pred_disc = discriminator(imgs, labels)
                #print(pred_disc)
                #print("grad: ",pred_disc.grad)
            else:
                #label_disc = torch.Tensor([0]).to(device=device,dtype=torch.float32)
                label_disc = fake.to(device=device,dtype=torch.float32)
                #print("img shape: ",imgs.shape)
                #print("pred shape: ",pred.shape)
                pred_disc = discriminator(imgs, pred.detach())
                #print(pred_disc)
                #print("grad: ",pred_disc.grad)
            
            
            loss = DiceLoss(pred,labels)

            optimizer.zero_grad()
            loss.backward(retain_graph=True)
            optimizer.step()
            
            loss = BCELoss(pred_disc,label_disc)
            #loss = torch.norm(pred_disc-label_disc)

            optimizer.zero_grad()
            loss.backward(retain_graph=False)
            optimizer.step()
            
            running_loss += loss.item()
            step += 1
            
            if step % 10 == 9:    # print every 10 mini-batches
                print('[%d, %5d] loss: %.3f' %(epoch + 1, step + 1, running_loss / 10))
                running_loss = 0.0
                
            if step%50 == 49:
                val_loss = 0
                for batch in valid_loader:
                    imgs,labels = batch
                    imgs = imgs.to(device=device,dtype=torch.float32)
                    labels = labels.to(device=device,dtype=torch.float32)
                    
                    with torch.no_grad():
                        pred = net(imgs)
                    val_loss += DiceLoss(pred,labels)
                print('[%d, %5d] validation loss: %.3f' %(epoch + 1, step + 1, val_loss / len(valid_loader)))
                scheduler.step(val_loss)

In [7]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
net = UNet(init_features=32)
discriminator = Discriminator()

net = net.to(device)
discriminator = discriminator.to(device)

root_dir = "/home/zhenyuli/workspace/us_robot/DataSet/SimDataset2"
try:
    #train(net=net,device=device,resize_to=None,epochs=50,batch_size=5)
    train_adversarial(net=net,discriminator=discriminator,device=device,resize_to=None,epochs=20,batch_size=2)
except KeyboardInterrupt:
    sys.exit()

[1,    10] loss: 0.628
[1,    20] loss: 0.685
[1,    30] loss: 0.718
[1,    40] loss: 0.705
[1,    50] loss: 0.706
[1,    50] validation loss: 0.821
[1,    60] loss: 0.697
[1,    70] loss: 0.688
[1,    80] loss: 0.681
[1,    90] loss: 0.706
[1,   100] loss: 0.689
[1,   100] validation loss: 0.616
[1,   110] loss: 0.697
[1,   120] loss: 0.700
[1,   130] loss: 0.697
[1,   140] loss: 0.693
[1,   150] loss: 0.705
[1,   150] validation loss: 0.404
[1,   160] loss: 0.694
[1,   170] loss: 0.689
[1,   180] loss: 0.680
[1,   190] loss: 0.700
[1,   200] loss: 0.684
[1,   200] validation loss: 0.296
[1,   210] loss: 0.705
[1,   220] loss: 0.705
[1,   230] loss: 0.696
[1,   240] loss: 0.700
[1,   250] loss: 0.691
[1,   250] validation loss: 0.256
[1,   260] loss: 0.677
[1,   270] loss: 0.700
[1,   280] loss: 0.693
[1,   290] loss: 0.705
[1,   300] loss: 0.703
[1,   300] validation loss: 0.237
[1,   310] loss: 0.681
[1,   320] loss: 0.693
[1,   330] loss: 0.682
[1,   340] loss: 0.700
[1,   350] los

[4,  2760] loss: 0.681
[4,  2770] loss: 0.689
[4,  2780] loss: 0.692
[4,  2790] loss: 0.684
[4,  2800] loss: 0.681
[4,  2800] validation loss: 0.889
[4,  2810] loss: 0.684
[4,  2820] loss: 0.695
[4,  2830] loss: 0.684
[4,  2840] loss: 0.696
[4,  2850] loss: 0.693
[4,  2850] validation loss: 0.879
[4,  2860] loss: 0.691
[4,  2870] loss: 0.694
[4,  2880] loss: 0.691
[4,  2890] loss: 0.683
[4,  2900] loss: 0.692
[4,  2900] validation loss: 0.876
[4,  2910] loss: 0.686
[4,  2920] loss: 0.681
[4,  2930] loss: 0.700
[4,  2940] loss: 0.684
[4,  2950] loss: 0.685
[4,  2950] validation loss: 0.869
[4,  2960] loss: 0.680
[4,  2970] loss: 0.683
[4,  2980] loss: 0.695
[4,  2990] loss: 0.692
[4,  3000] loss: 0.699
[4,  3000] validation loss: 0.866
[4,  3010] loss: 0.691
[4,  3020] loss: 0.681
[4,  3030] loss: 0.687
[4,  3040] loss: 0.683
[4,  3050] loss: 0.696
[4,  3050] validation loss: 0.839
[4,  3060] loss: 0.692
[4,  3070] loss: 0.693
[4,  3080] loss: 0.696
[4,  3090] loss: 0.687
[4,  3100] los

[7,  5510] loss: 0.694
[7,  5520] loss: 0.688
[7,  5530] loss: 0.696
[7,  5540] loss: 0.692
[7,  5550] loss: 0.684
[7,  5550] validation loss: 0.866
[7,  5560] loss: 0.689
[7,  5570] loss: 0.682
[7,  5580] loss: 0.696
[7,  5590] loss: 0.676
[7,  5600] loss: 0.691
[7,  5600] validation loss: 0.879
[7,  5610] loss: 0.692
[7,  5620] loss: 0.680
[7,  5630] loss: 0.686
[7,  5640] loss: 0.686
[7,  5650] loss: 0.691
[7,  5650] validation loss: 0.859
[7,  5660] loss: 0.690
[7,  5670] loss: 0.685
[7,  5680] loss: 0.686
[7,  5690] loss: 0.690
[7,  5700] loss: 0.691
[7,  5700] validation loss: 0.856
[7,  5710] loss: 0.687
[7,  5720] loss: 0.694
[7,  5730] loss: 0.700
[7,  5740] loss: 0.691
[7,  5750] loss: 0.684
[7,  5750] validation loss: 0.889
[7,  5760] loss: 0.695
[7,  5770] loss: 0.688
[7,  5780] loss: 0.682
[7,  5790] loss: 0.682
[7,  5800] loss: 0.690
[7,  5800] validation loss: 0.846
[7,  5810] loss: 0.690
[7,  5820] loss: 0.687
[7,  5830] loss: 0.697
[7,  5840] loss: 0.685
[7,  5850] los

[10,  8260] loss: 0.686
[10,  8270] loss: 0.689
[10,  8280] loss: 0.699
[10,  8290] loss: 0.704
[10,  8300] loss: 0.690
[10,  8300] validation loss: 0.826
[10,  8310] loss: 0.688
[10,  8320] loss: 0.688
[10,  8330] loss: 0.674
[10,  8340] loss: 0.686
[10,  8350] loss: 0.690
[10,  8350] validation loss: 0.859
[10,  8360] loss: 0.685
[10,  8370] loss: 0.687
[10,  8380] loss: 0.690
[10,  8390] loss: 0.686
[10,  8400] loss: 0.691
[10,  8400] validation loss: 0.859
[10,  8410] loss: 0.698
[10,  8420] loss: 0.691
[10,  8430] loss: 0.693
[10,  8440] loss: 0.681
[10,  8450] loss: 0.693
[10,  8450] validation loss: 0.856
[10,  8460] loss: 0.693
[10,  8470] loss: 0.687
[10,  8480] loss: 0.693
[10,  8490] loss: 0.683
[10,  8500] loss: 0.676
[10,  8500] validation loss: 0.916
[10,  8510] loss: 0.688
[10,  8520] loss: 0.693
[10,  8530] loss: 0.683
[10,  8540] loss: 0.682
[10,  8550] loss: 0.689
[10,  8550] validation loss: 0.849
[10,  8560] loss: 0.696
[10,  8570] loss: 0.690
[10,  8580] loss: 0.68

[13, 10910] loss: 0.692
[13, 10920] loss: 0.690
[13, 10930] loss: 0.686
[13, 10940] loss: 0.682
[13, 10950] loss: 0.687
[13, 10950] validation loss: 0.856
[13, 10960] loss: 0.691
[13, 10970] loss: 0.696
[13, 10980] loss: 0.689
[13, 10990] loss: 0.682
[13, 11000] loss: 0.688
[13, 11000] validation loss: 0.849
[13, 11010] loss: 0.693
[13, 11020] loss: 0.694
[13, 11030] loss: 0.685
[13, 11040] loss: 0.681
[13, 11050] loss: 0.690
[13, 11050] validation loss: 0.889
[13, 11060] loss: 0.682
[13, 11070] loss: 0.696
[13, 11080] loss: 0.690
[13, 11090] loss: 0.696
[13, 11100] loss: 0.680
[13, 11100] validation loss: 0.886
[13, 11110] loss: 0.683
[13, 11120] loss: 0.684
[13, 11130] loss: 0.685
[13, 11140] loss: 0.688
[13, 11150] loss: 0.685
[13, 11150] validation loss: 0.866
[13, 11160] loss: 0.676
[13, 11170] loss: 0.686
[13, 11180] loss: 0.689
[13, 11190] loss: 0.688
[13, 11200] loss: 0.680
[13, 11200] validation loss: 0.896
[13, 11210] loss: 0.681
[13, 11220] loss: 0.699
[13, 11230] loss: 0.68

[16, 13560] loss: 0.689
[16, 13570] loss: 0.688
[16, 13580] loss: 0.680
[16, 13590] loss: 0.681
[16, 13600] loss: 0.690
[16, 13600] validation loss: 0.839
[16, 13610] loss: 0.681
[16, 13620] loss: 0.700
[16, 13630] loss: 0.690
[16, 13640] loss: 0.687
[16, 13650] loss: 0.683
[16, 13650] validation loss: 0.859
[16, 13660] loss: 0.676
[16, 13670] loss: 0.689
[16, 13680] loss: 0.680
[16, 13690] loss: 0.696
[16, 13700] loss: 0.684
[16, 13700] validation loss: 0.846
[16, 13710] loss: 0.693
[16, 13720] loss: 0.690
[16, 13730] loss: 0.683
[16, 13740] loss: 0.696
[16, 13750] loss: 0.683
[16, 13750] validation loss: 0.876
[16, 13760] loss: 0.685
[16, 13770] loss: 0.682
[16, 13780] loss: 0.683
[16, 13790] loss: 0.687
[16, 13800] loss: 0.695
[16, 13800] validation loss: 0.859
[16, 13810] loss: 0.691
[16, 13820] loss: 0.695
[16, 13830] loss: 0.691
[16, 13840] loss: 0.690
[16, 13850] loss: 0.694
[16, 13850] validation loss: 0.868
[16, 13860] loss: 0.680
[16, 13870] loss: 0.676
[16, 13880] loss: 0.67

[19, 16210] loss: 0.696
[19, 16220] loss: 0.683
[19, 16230] loss: 0.693
[19, 16240] loss: 0.691
[19, 16250] loss: 0.691
[19, 16250] validation loss: 0.889
[19, 16260] loss: 0.685
[19, 16270] loss: 0.680
[19, 16280] loss: 0.699
[19, 16290] loss: 0.692
[19, 16300] loss: 0.686
[19, 16300] validation loss: 0.878
[19, 16310] loss: 0.694
[19, 16320] loss: 0.686
[19, 16330] loss: 0.690
[19, 16340] loss: 0.675
[19, 16350] loss: 0.696
[19, 16350] validation loss: 0.807
[19, 16360] loss: 0.699
[19, 16370] loss: 0.696
[19, 16380] loss: 0.686
[19, 16390] loss: 0.685
[19, 16400] loss: 0.691
[19, 16400] validation loss: 0.866
[19, 16410] loss: 0.693
[19, 16420] loss: 0.682
[19, 16430] loss: 0.688
[19, 16440] loss: 0.674
[19, 16450] loss: 0.699
[19, 16450] validation loss: 0.839
[19, 16460] loss: 0.678
[19, 16470] loss: 0.699
[19, 16480] loss: 0.693
[19, 16490] loss: 0.689
[19, 16500] loss: 0.687
[19, 16500] validation loss: 0.898
[19, 16510] loss: 0.693
[19, 16520] loss: 0.681
[19, 16530] loss: 0.69

In [9]:
test_dir = "/home/zhenyuli/workspace/us_robot/DataSet/SimDatasetTest2"
testset_list = os.listdir(test_dir)
resize_to=None
transform_image = transforms.Compose([
    #transforms.Resize(resize_to),
    transforms.ToTensor(),
    transforms.Normalize(0.5,0.5) #Division by 255 is done, when the transformation assumes an image.
    ])
transform_label = transforms.Compose([
    #transforms.Resize(resize_to),
    transforms.ToTensor()
    ])
invtransform_label = transforms.Compose([
    transforms.ToPILImage(),
    #transforms.Resize([1000,500])
    ])
for sample in testset_list:
    image_path = os.path.join(test_dir,sample,"image.png")
    label_path = os.path.join(test_dir,sample,"label.png")
    
    img = Image.open(image_path)
    img = transform_image(img)
    img = img.to(device)
    img = img.unsqueeze(0)
    label = Image.open(label_path)
    label = transform_label(label).to(device)
    label = label.unsqueeze(0)
    
    
    with torch.no_grad():
        pred = net(img)
    
    DiceIndex = (1 - DiceLoss(pred,label)).cpu().item()

    pred = invtransform_label(pred.cpu().squeeze(0))
    fname = "pred%.2f.png"%DiceIndex
    sav_path = os.path.join(test_dir,sample,fname)
    pred.save(sav_path)
    


In [9]:
#for python
torch.save(net.state_dict(), './unet_usseg.pth')

#for c++
traced_script_module = torch.jit.trace(net, img)
traced_script_module.save("./unet_usseg_traced.pt")

In [3]:
test_dir = "/home/zhenyuli/workspace/us_robot/simulator/SimDatasetTest2"
testset_list = os.listdir(test_dir)
resize_to=None
transform_image = transforms.Compose([
    #transforms.Resize(resize_to),
    transforms.ToTensor(),
    transforms.Normalize(0.5,0.5) #Division by 255 is done, when the transformation assumes an image.
    ])
transform_label = transforms.Compose([
    #transforms.Resize(resize_to),
    transforms.ToTensor()
    ])
invtransform_label = transforms.Compose([
    transforms.ToPILImage(),
    #transforms.Resize([1000,500])
    ])

In [18]:
sample = '0002'

In [16]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
image_path = os.path.join(test_dir,sample,"image.png")
label_path = os.path.join(test_dir,sample,"label.png")
pred_path = os.path.join(test_dir,sample,"pred0.06.png")

img = Image.open(image_path)
img = transform_image(img)
img = img.to(device)
img = img.unsqueeze(0)
label = Image.open(label_path)
label = transform_label(label).to(device)
label = label.unsqueeze(0)
pred = Image.open(pred_path)
pred = transform_label(pred).to(device)
pred = pred.unsqueeze(0)

#with torch.no_grad():
#    pred = net(img)

DiceIndex = (1 - DiceLoss(pred,label)).cpu().item()


In [19]:
pred_path = os.path.join(test_dir,sample,"pred0.06.png")
pred = Image.open(pred_path)
pred = transform_label(pred)

In [2]:
loss = nn.BCELoss()

In [5]:
x = torch.Tensor([0.1])
y = torch.Tensor([1])

In [6]:
loss(x,y)

tensor(2.3026)

In [28]:
x = torch.rand([1,2,2])

In [31]:
x.size()[2:]

torch.Size([2])

In [32]:
x.size()

torch.Size([1, 2, 2])