In [0]:
from google.colab import drive
drive.mount('/content/gdrive')

In [0]:
%cd "/content/"
%mkdir data
%cd data
!unzip "/content/gdrive/My Drive/mlmi/input.zip"
!unzip "/content/gdrive/My Drive/mlmi/masks.zip"
%ls

In [0]:
!pip install numpy
!pip install torch
!pip install matplotlib
!pip install scipy
!pip install torchvision
!pip install tqdm
!pip install visdom
!pip install nibabel
!pip install scikit-image
!pip install h5py
!pip install pandas
!pip install dominate
!pip install pydicom
!pip install opencv-python
!pip install scikit-learn
!pip install https://github.com/ozan-oktay/torchsample/tarball/master#egg=torchsample-0.1.3

In [0]:
import torch.nn as nn
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import pydicom
import numpy as np
import pandas as pd
import os
import torch.optim as optim
from tqdm import tqdm

In [0]:
def init_weights(net, init_type='normal'):
    #print('initialization method [%s]' % init_type)
    if init_type == 'normal':
        net.apply(weights_init_normal)
    elif init_type == 'xavier':
        net.apply(weights_init_xavier)
    elif init_type == 'kaiming':
        net.apply(weights_init_kaiming)
    elif init_type == 'orthogonal':
        net.apply(weights_init_orthogonal)
    else:
        raise NotImplementedError('initialization method [%s] is not implemented' % init_type)

def weights_init_kaiming(m):
    classname = m.__class__.__name__
    #print(classname)
    if classname.find('Conv') != -1:
        nn.init.kaiming_normal(m.weight.data, a=0, mode='fan_in')
    elif classname.find('Linear') != -1:
        nn.init.kaiming_normal(m.weight.data, a=0, mode='fan_in')
    elif classname.find('BatchNorm') != -1:
        nn.init.normal(m.weight.data, 1.0, 0.02)
        nn.init.constant(m.bias.data, 0.0)

In [0]:
class unetConv2(nn.Module):
    def __init__(self, in_size, out_size, is_batchnorm, n=2, ks=3, stride=1, padding=1):
        super(unetConv2, self).__init__()
        self.n = n
        self.ks = ks
        self.stride = stride
        self.padding = padding
        s = stride
        p = padding
        if is_batchnorm:
            for i in range(1, n+1):
                conv = nn.Sequential(nn.Conv2d(in_size, out_size, ks, s, p),
                                     nn.BatchNorm2d(out_size),
                                     nn.ReLU(inplace=True),)
                setattr(self, 'conv%d'%i, conv)
                in_size = out_size

        else:
            for i in range(1, n+1):
                conv = nn.Sequential(nn.Conv2d(in_size, out_size, ks, s, p),
                                     nn.ReLU(inplace=True),)
                setattr(self, 'conv%d'%i, conv)
                in_size = out_size

        # initialise the blocks
        for m in self.children():
            init_weights(m, init_type='kaiming')

    def forward(self, inputs):
        x = inputs
        for i in range(1, self.n+1):
            conv = getattr(self, 'conv%d'%i)
            x = conv(x)

        return x

class unetUp2(nn.Module):
    def __init__(self, in_size, out_size, is_deconv, is_batchnorm=True):
        super(unetUp2, self).__init__()
        self.conv = unetConv2(in_size + out_size, out_size, is_batchnorm)
        if is_deconv:
            self.up = nn.ConvTranspose2d(in_size, out_size, kernel_size=4, stride=2, padding=1)
        else:
            self.up = nn.UpsamplingBilinear2d(scale_factor=2)

        # initialise the blocks
        for m in self.children():
            if m.__class__.__name__.find('unetConv2') != -1: continue
            init_weights(m, init_type='kaiming')

    def forward(self, inputs1, inputs2):
        outputs2 = self.up(inputs2)
        # offset = outputs2.size()[1] - inputs1.size()[1]
        # padding = [0, 0, 0, 0, offset // 2, offset // 2]
        # outputs1 = F.pad(inputs1, padding)
        outputs1 = inputs1
        cat = torch.cat([outputs1, outputs2], 1)
        return self.conv(cat)

class UnetGridGatingSignal2(nn.Module):
    def __init__(self, in_size, out_size, ks=1, is_batchnorm=True):
        super(UnetGridGatingSignal2, self).__init__()

        if is_batchnorm:
            self.conv1 = nn.Sequential(nn.Conv2d(in_size, out_size, ks),
                                       nn.BatchNorm2d(out_size),
                                       nn.ReLU(inplace=True),
                                       )
        else:
            self.conv1 = nn.Sequential(nn.Conv2d(in_size, out_size, ks),
                                       nn.ReLU(inplace=True),
                                       )

        # initialise the blocks
        for m in self.children():
            init_weights(m, init_type='kaiming')

    def forward(self, inputs):
        outputs = self.conv1(inputs)
        return outputs

class UnetDsv2(nn.Module):
    def __init__(self, in_size, out_size, scale_factor):
        super(UnetDsv2, self).__init__()
        self.dsv = nn.Sequential(nn.Conv2d(in_size, out_size, kernel_size=1, stride=1, padding=0),
                                 nn.Upsample(scale_factor=scale_factor, mode='bilinear'), )

    def forward(self, input):
        return self.dsv(input)

class _GridAttentionBlockND(nn.Module):
    def __init__(self, in_channels, gating_channels, inter_channels=None, dimension=3, mode='concatenation',
                 sub_sample_factor=(2,2,2)):
        super(_GridAttentionBlockND, self).__init__()

        assert dimension in [2, 3]
        assert mode in ['concatenation', 'concatenation_debug', 'concatenation_residual']

        # Downsampling rate for the input featuremap
        if isinstance(sub_sample_factor, tuple): self.sub_sample_factor = sub_sample_factor
        elif isinstance(sub_sample_factor, list): self.sub_sample_factor = tuple(sub_sample_factor)
        else: self.sub_sample_factor = tuple([sub_sample_factor]) * dimension

        # Default parameter set
        self.mode = mode
        self.dimension = dimension
        self.sub_sample_kernel_size = self.sub_sample_factor

        # Number of channels (pixel dimensions)
        self.in_channels = in_channels
        self.gating_channels = gating_channels
        self.inter_channels = inter_channels

        if self.inter_channels is None:
            self.inter_channels = in_channels // 2
            if self.inter_channels == 0:
                self.inter_channels = 1

        if dimension == 3:
            conv_nd = nn.Conv3d
            bn = nn.BatchNorm3d
            self.upsample_mode = 'trilinear'
        elif dimension == 2:
            conv_nd = nn.Conv2d
            bn = nn.BatchNorm2d
            self.upsample_mode = 'bilinear'
        else:
            raise NotImplemented

        # Output transform
        self.W = nn.Sequential(
            conv_nd(in_channels=self.in_channels, out_channels=self.in_channels, kernel_size=1, stride=1, padding=0),
            bn(self.in_channels),
        )

        # Theta^T * x_ij + Phi^T * gating_signal + bias
        self.theta = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
                             kernel_size=self.sub_sample_kernel_size, stride=self.sub_sample_factor, padding=0, bias=False)
        self.phi = conv_nd(in_channels=self.gating_channels, out_channels=self.inter_channels,
                           kernel_size=1, stride=1, padding=0, bias=True)
        self.psi = conv_nd(in_channels=self.inter_channels, out_channels=1, kernel_size=1, stride=1, padding=0, bias=True)

        # Initialise weights
        for m in self.children():
            init_weights(m, init_type='kaiming')

        # Define the operation
        if mode == 'concatenation':
            self.operation_function = self._concatenation
        elif mode == 'concatenation_debug':
            self.operation_function = self._concatenation_debug
        elif mode == 'concatenation_residual':
            self.operation_function = self._concatenation_residual
        else:
            raise NotImplementedError('Unknown operation function.')


    def forward(self, x, g):
        '''
        :param x: (b, c, t, h, w)
        :param g: (b, g_d)
        :return:
        '''

        output = self.operation_function(x, g)
        return output

    def _concatenation(self, x, g):
        input_size = x.size()
        batch_size = input_size[0]
        assert batch_size == g.size(0)

        # theta => (b, c, t, h, w) -> (b, i_c, t, h, w) -> (b, i_c, thw)
        # phi   => (b, g_d) -> (b, i_c)
        theta_x = self.theta(x)
        theta_x_size = theta_x.size()

        # g (b, c, t', h', w') -> phi_g (b, i_c, t', h', w')
        #  Relu(theta_x + phi_g + bias) -> f = (b, i_c, thw) -> (b, i_c, t/s1, h/s2, w/s3)
        phi_g = F.interpolate(self.phi(g), size=theta_x_size[2:], mode=self.upsample_mode)
        f = F.relu(theta_x + phi_g, inplace=True)

        #  psi^T * f -> (b, psi_i_c, t/s1, h/s2, w/s3)
        sigm_psi_f = torch.sigmoid(self.psi(f))

        # upsample the attentions and multiply
        sigm_psi_f = F.interpolate(sigm_psi_f, size=input_size[2:], mode=self.upsample_mode)
        y = sigm_psi_f.expand_as(x) * x
        W_y = self.W(y)

        return W_y, sigm_psi_f

class GridAttentionBlock2D(_GridAttentionBlockND):
    def __init__(self, in_channels, gating_channels, inter_channels=None, mode='concatenation',
                 sub_sample_factor=(2,2,2)):
        super(GridAttentionBlock2D, self).__init__(in_channels,
                                                   inter_channels=inter_channels,
                                                   gating_channels=gating_channels,
                                                   dimension=2, mode=mode,
                                                   sub_sample_factor=sub_sample_factor
                                                   )


class MultiAttentionBlock(nn.Module):
    def __init__(self, in_size, gate_size, inter_size, nonlocal_mode, sub_sample_factor):
        super(MultiAttentionBlock, self).__init__()
        self.gate_block_1 = GridAttentionBlock2D(in_channels=in_size, gating_channels=gate_size,
                                                 inter_channels=inter_size, mode=nonlocal_mode,
                                                 sub_sample_factor= sub_sample_factor)
        self.combine_gates = nn.Sequential(nn.Conv2d(in_size, in_size, kernel_size=1, stride=1, padding=0),
                                           nn.BatchNorm2d(in_size),
                                           nn.ReLU(inplace=True)
                                           )

        # initialise the blocks
        for m in self.children():
            if m.__class__.__name__.find('GridAttentionBlock2D') != -1: continue
            init_weights(m, init_type='kaiming')

    def forward(self, input, gating_signal):
        gate_1, attention_1 = self.gate_block_1(input, gating_signal)
        return self.combine_gates(gate_1), attention_1

In [0]:
class unet_simm(nn.Module):

    def __init__(self, feature_scale, n_classes, is_deconv, in_channels,
                 nonlocal_mode, attention_dsample, is_batchnorm):
        super(unet_simm, self).__init__()
        self.is_deconv = is_deconv
        self.in_channels = in_channels
        self.is_batchnorm = is_batchnorm
        self.feature_scale = feature_scale
        # self.use_cuda = Config.use_cuda

        filters = [64, 128, 256, 512, 1024]
        filters = [int(x / self.feature_scale) for x in filters]
        # filter [16, 32, 64, 128, 256]

        # downsampling
        self.conv1 = unetConv2(self.in_channels, filters[0], self.is_batchnorm, ks=3)
        self.maxpool1 = nn.MaxPool2d(kernel_size=2)

        self.conv2 = unetConv2(filters[0], filters[1], self.is_batchnorm, ks=3)
        self.maxpool2 = nn.MaxPool2d(kernel_size=2)

        self.conv3 = unetConv2(filters[1], filters[2], self.is_batchnorm, ks=3)
        self.maxpool3 = nn.MaxPool2d(kernel_size=2)

        self.conv4 = unetConv2(filters[2], filters[3], self.is_batchnorm, ks=3)
        self.maxpool4 = nn.MaxPool2d(kernel_size=2)

        self.center = unetConv2(filters[3], filters[4], self.is_batchnorm, ks=3)
        self.gating = UnetGridGatingSignal2(filters[4], filters[4], ks=1, is_batchnorm=self.is_batchnorm)

        # attention blocks
        self.attentionblock2 = MultiAttentionBlock(in_size=filters[1], gate_size=filters[2], inter_size=filters[1],
                                                   nonlocal_mode=nonlocal_mode, sub_sample_factor= attention_dsample)
        self.attentionblock3 = MultiAttentionBlock(in_size=filters[2], gate_size=filters[3], inter_size=filters[2],
                                                   nonlocal_mode=nonlocal_mode, sub_sample_factor= attention_dsample)
        self.attentionblock4 = MultiAttentionBlock(in_size=filters[3], gate_size=filters[4], inter_size=filters[3],
                                                   nonlocal_mode=nonlocal_mode, sub_sample_factor= attention_dsample)

        # upsampling
        self.up_concat4 = unetUp2(filters[4], filters[3], False, is_batchnorm)
        # self.up_concat4 = unetUp2(512, filters[3], False, is_batchnorm)
        self.up_concat3 = unetUp2(filters[3], filters[2], False, is_batchnorm)
        self.up_concat2 = unetUp2(filters[2], filters[1], False, is_batchnorm)
        self.up_concat1 = unetUp2(filters[1], filters[0], False, is_batchnorm)

        # deep supervision
        self.dsv4 = UnetDsv2(in_size=filters[3], out_size=n_classes, scale_factor=8)
        self.dsv3 = UnetDsv2(in_size=filters[2], out_size=n_classes, scale_factor=4)
        self.dsv2 = UnetDsv2(in_size=filters[1], out_size=n_classes, scale_factor=2)
        self.dsv1 = nn.Conv2d(in_channels=filters[0], out_channels=n_classes, kernel_size=1)

        # final conv (without any concat)
        self.final = nn.Conv2d(n_classes*4, n_classes, 1)

        # initialise weights
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init_weights(m, init_type='kaiming')
            elif isinstance(m, nn.BatchNorm2d):
                init_weights(m, init_type='kaiming')

    def forward(self, inputs):

        # if self.use_cuda == False:
        #     # inputs = inputs.to(dtype=torch.float64)
        #     inputs = inputs.double()

        # inputs = inputs.double()

        # Feature Extraction
        conv1 = self.conv1(inputs)
        maxpool1 = self.maxpool1(conv1)

        conv2 = self.conv2(maxpool1)
        maxpool2 = self.maxpool2(conv2)

        conv3 = self.conv3(maxpool2)
        maxpool3 = self.maxpool3(conv3)

        conv4 = self.conv4(maxpool3)
        maxpool4 = self.maxpool4(conv4)

        # Gating Signal Generation
        center = self.center(maxpool4)
        gating = self.gating(center)

        # Attention Mechanism
        # Upscaling Part (Decoder)
        g_conv4, att4 = self.attentionblock4(conv4, gating)
        up4 = self.up_concat4(g_conv4, center)
        g_conv3, att3 = self.attentionblock3(conv3, up4)
        up3 = self.up_concat3(g_conv3, up4)
        g_conv2, att2 = self.attentionblock2(conv2, up3)
        up2 = self.up_concat2(g_conv2, up3)
        up1 = self.up_concat1(conv1, up2)

        # Deep Supervision
        dsv4 = self.dsv4(up4)
        dsv3 = self.dsv3(up3)
        dsv2 = self.dsv2(up2)
        dsv1 = self.dsv1(up1)
        final = self.final(torch.cat([dsv1,dsv2,dsv3,dsv4], dim=1))

        final = torch.sigmoid(final)

        return final

In [0]:
def rle2mask(rle, width, height):
    mask= np.zeros(width* height)
    array = np.asarray([int(x) for x in rle.split()])
    starts = array[0::2]
    lengths = array[1::2]

    current_position = 0
    for index, start in enumerate(starts):
        current_position += start
        mask[current_position:current_position+lengths[index]] = 255
        current_position += lengths[index]

    return mask.reshape(width, height)
    
class SIMMDataset(Dataset):
    """SIMM dataset."""

    def __init__(self, root_dir, dir_postfix, split, transform=None, preload_data=False):
        """
        Args:
            dicomPaths (Array<string>): Array of DICOM file Paths.
            mask_csv_file (string): csv file with encoded masks (rle).
            transform (callable, optional): Optional transform to be applied
                on a sample.
        """

        self.root_dir = root_dir
        self.dir_postfix = dir_postfix

        self.im_height = 1024
        self.im_width = 1024
        self.im_chan = 1

        ## Read masks file
        mask_csv_file = root_dir + '/train-rle.csv' 
        self.encodedMasks = pd.read_csv(mask_csv_file, names=['ImageId', 'EncodedPixels'], index_col='ImageId')

        ## Read dataset file names
        dsFile = root_dir + '/simm_DS_' + split + '.csv'
        dsFileData = pd.read_csv(dsFile)
        self.dicomPaths = dsFileData['path'].tolist()

        self.transform = transform

    def __len__(self):
        return len(self.dicomPaths)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        dPath = self.dicomPaths[idx]
        dicom = pydicom.dcmread(self.root_dir + self.dir_postfix + dPath)
        
#         image = np.zeros((1, im_height, im_width, im_chan), dtype=np.uint8)
#         image = np.expand_dims(dicom.pixel_array, axis=2)
        image = np.array(dicom.pixel_array)
        
        # get mask (in rle) from csv
        landmarks = np.zeros((self.im_height, self.im_width), dtype=np.bool)
        
        fileId = dPath.split('/')[-1][:-4]
        rle = self.encodedMasks.loc[fileId, 'EncodedPixels']
        try:
            if type(rle) == str: # if single rle
                decodedRle = rle2mask(rle, self.im_height, self.im_width)
#                 landmarks = np.expand_dims(decodedRle, axis=2)
                landmarks = decodedRle
            else: # if multiple rle
                for x in rle:
                    decodedRle = rle2mask(x, self.im_height, self.im_width)
                    landmarks = landmarks + decodedRle
#                     landmarks = landmarks + np.expand_dims(decodedRle, axis=2)
        except Exception as e:
            print(e)
            
        #### TODO - IMPORTANT::: CHECK THIS  
        ## QUESTION: SHOULD WE TRANSPOSE THE MASK IN THE GETITEM FUNCTION 
        ## BECAUSE WHEN PLOTING THE GRAPHS WE HAVE TO TRANSPOSE IT.
        landmarks = landmarks.T

        # for some images, we have multiple masks, so we are adding the masks
        # which results in some pixels to > 1
        landmarks = (landmarks >= 1).astype('float64')
            
        sample = {'image': image, 'mask': landmarks}

        if self.transform:
            sample = self.transform(sample)

        img = np.expand_dims(sample['image'], axis=0)
        mk = np.expand_dims(sample['mask'], axis=0)

        img = img / 255

        return img, mk

In [0]:
class SIMMSoftDiceLoss(nn.Module):
    def __init__(self):
        super(SIMMSoftDiceLoss, self).__init__()

    def forward(self, input, target):
        smooth = 0.01
        batch_size = input.size(0)

        # input = torch.sigmoid(input).view(batch_size, -1)
        input = input.view(batch_size, -1)
        target = target.contiguous().view(batch_size, -1)

        inter = torch.sum(input * target, 1) + smooth
        union = torch.sum(input, 1) + torch.sum(target, 1) + smooth

        score = torch.sum(2.0 * inter / union, 0)
        score = 1.0 - score / float(batch_size)
        
        return score

In [0]:
json_opts = {
  "use_cuda": True,
  "network_debug": 0,
  "training": {
    "arch_type": "simm_unet",
    "n_epochs": 100,
    "save_epoch_freq": 10,
    "lr_policy": "step",
    "lr_decay_iters": 25,
    "batchSize": 4,
    "preloadData": True,
    "network_debug": 0
  },
  "visualisation":{
    "display_port": 8097,
    "display_server": "http://54.89.248.230",
    "no_html": True,
    "display_winsize": 256,
    "display_id": 1,
    "display_single_pane_ncols": 0
  },
  "data_path": {
    "simm_unet": "/content/data",
    "postfix": "/input/siim/"
  },
  "augmentation": {
  },
  "model":{
    "type":"seg",
    "continue_train": False,
    "which_epoch": -1,
    "model_type": "unet_simm",
    "tensor_dim": "2D",
    "division_factor": 16,
    "input_nc": 1,
    "output_nc": 1,
    "lr_rate": 1e-4,
    "l2_reg_weight": 1e-6,
    "feature_scale": 4,
    "gpu_ids": [],
    "isTrain": True,
    "checkpoints_dir": "./checkpoints",
    "experiment_name": "experiment_unet_simm",
    "criterion": "SIMMSoftDiceLoss",
    "optim": "adam",
    "n_classes": 1
  }
}

import json
import collections
def json_file_to_pyobj(jsonStr):
    def _json_object_hook(d): return collections.namedtuple('X', d.keys())(*d.values())
    def json2obj(data): return json.loads(data, object_hook=_json_object_hook)
    return json2obj(jsonStr)

json_opts = json_file_to_pyobj(json.dumps(json_opts))
train_opts = json_opts.training
model_opts = json_opts.model

In [0]:
numWorkers = 2
# Setup Dataset and Augmentation
ds_class = SIMMDataset
ds_path = json_opts.data_path.simm_unet
ds_postfix = json_opts.data_path.postfix

# ds_transform = get_dataset_transformation(arch_type, opts=json_opts.augmentation)

train_dataset = ds_class(ds_path, ds_postfix, split='train',      transform=None, preload_data=train_opts.preloadData)
valid_dataset = ds_class(ds_path, ds_postfix, split='validation', transform=None, preload_data=train_opts.preloadData)
# test_dataset  = ds_class(ds_path, split='test',       transform=None, preload_data=train_opts.preloadData)

train_loader = DataLoader(dataset=train_dataset, num_workers=numWorkers, batch_size=train_opts.batchSize, shuffle=True)
valid_loader = DataLoader(dataset=valid_dataset, num_workers=numWorkers, batch_size=train_opts.batchSize, shuffle=False)
# test_loader  = DataLoader(dataset=test_dataset,  num_workers=numWorkers, batch_size=train_opts.batchSize, shuffle=False)

In [11]:
if torch.cuda.is_available():
  gpu_count = torch.cuda.device_count()
  print("Available GPU count:" + str(gpu_count))
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

Available GPU count:1
cuda:0


In [12]:
# Setup the NN Model
# (self, feature_scale=4, n_classes=21, is_deconv=True, in_channels=3,
#                  nonlocal_mode='concatenation', attention_dsample=(2,2,2), is_batchnorm=True)

model = unet_simm(n_classes=1,
                      is_batchnorm=True,
                      in_channels=1,
                      nonlocal_mode='concatenation',
                      feature_scale=4,
                      attention_dsample=(2,2),
                      is_deconv=False)
model = model.to(device)



In [0]:
# Visualisation Parameters
# visualizer = Visualiser(json_opts.visualisation, save_dir=model.save_dir)

In [13]:
# if Config.use_cuda:
    # torch.cuda.empty_cache()

criterion = SIMMSoftDiceLoss()
optimizer = optim.Adam(model.parameters(),
                        lr=model_opts.lr_rate,
                        betas=(0.9, 0.999),
                        weight_decay=model_opts.l2_reg_weight)

# optimizer = optim.SGD(params,
#                               lr=option.lr_rate,
#                               momentum=0.9,
#                               nesterov=True,
#                               weight_decay=option.l2_reg_weight)


# Training Function
# model.set_scheduler(train_opts)
for epoch in range(model_opts.which_epoch, train_opts.n_epochs):
    print('############# Running epoch: %d...\n' % (epoch))

    # Training Iterations
    running_loss = 0.0
    for epoch_iter, (images, labels) in tqdm(enumerate(train_loader, 1), total=len(train_loader)):
        # Make a training update
        inputs = images.float().to(device)
        masks = labels.to(device)
        # assert input.size() == target.size()
        
        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = model(inputs)
        loss = criterion(outputs, masks)
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item()
        if epoch_iter % 20 == 0:
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, epoch_iter + 1, running_loss / 20))
            running_loss = 0.0

    # Update the model learning rate
    # model.update_learning_rate()

############# Running epoch: -1...



  "See the documentation of nn.Upsample for details.".format(mode))
  5%|▍         | 20/401 [00:18<05:50,  1.09it/s]

[0,    21] loss: 0.964


 10%|▉         | 40/401 [00:37<05:41,  1.06it/s]

[0,    41] loss: 0.942


 15%|█▍        | 60/401 [00:56<05:30,  1.03it/s]

[0,    61] loss: 0.921


 20%|█▉        | 80/401 [01:16<05:17,  1.01it/s]

[0,    81] loss: 0.924


 25%|██▍       | 100/401 [01:36<04:54,  1.02it/s]

[0,   101] loss: 0.920


 30%|██▉       | 120/401 [01:55<04:31,  1.04it/s]

[0,   121] loss: 0.910


 35%|███▍      | 140/401 [02:14<04:11,  1.04it/s]

[0,   141] loss: 0.887


 40%|███▉      | 160/401 [02:34<03:54,  1.03it/s]

[0,   161] loss: 0.894


 41%|████▏     | 166/401 [02:40<03:48,  1.03it/s]

KeyboardInterrupt: ignored