In [28]:
!pip install tensorboardX
!pip install ml_collections
!pip install monai



In [29]:
!git clone --branch training https://github.com/lesnikowka/TRUNet.git


fatal: destination path 'TRUNet' already exists and is not an empty directory.


In [30]:
import logging
import os
import random
import numpy as np
import torch
import torch.nn as nn
from tensorboardX import SummaryWriter
from torch.utils.data import DataLoader
from statistics import fmean


def dice_score(pred, gt):  # data in shape [batch, classes, h, w, d]
    """Return the mean foreground Dice score of a batch.

    For every sample and every class channel except channel 0, the score is
    ``2 * sum(pred[gt == 1]) / (sum(pred) + sum(gt))``; it is 0 whenever the
    numerator sum is zero. Scores are averaged over the classes of a sample
    and then over the samples.

    Args:
        pred: NumPy array of predictions indexed as ``[batch, class, ...]``.
        gt: NumPy array of reference masks with the same indexing; voxels
            equal to 1 mark class membership.

    Returns:
        The averaged Dice score as a Python float.
    """
    sample_means = []
    for sample_idx in range(gt.shape[0]):
        class_scores = []
        # Channel 0 is the background and is left out of the average.
        for class_idx in range(1, gt.shape[1]):
            pred_map = pred[sample_idx, class_idx]
            gt_map = gt[sample_idx, class_idx]
            overlap = np.sum(pred_map[gt_map == 1])
            if overlap == 0:
                # No overlap: score 0 (this also avoids dividing by zero).
                class_scores.append(0)
                continue
            denominator = np.sum(pred_map) + np.sum(gt_map)
            class_scores.append((overlap * 2.0 / denominator).item())
        sample_means.append(fmean(class_scores))
    return fmean(sample_means)


def one_hot_encoder(input_tensor, n_classes):
    """One-hot encode a label tensor along dimension 1.

    One mask ``input_tensor == k`` is built for each ``k`` in
    ``range(n_classes)`` and the masks are concatenated along dim 1.

    Args:
        input_tensor: torch tensor of class indices; callers in this file pass
            it with a singleton dim 1.
        n_classes: number of classes, i.e. number of masks produced.

    Returns:
        Float tensor holding the concatenated masks.
    """
    class_masks = [input_tensor == class_id for class_id in range(n_classes)]
    return torch.cat(class_masks, dim=1).float()


def to_one_arr_encoding(input_tensor):  # input shape: [batch, channels, h, w, d]
    """Collapse per-channel binary masks into a single label map.

    Voxels equal to 1 in channel ``c`` contribute the value ``c + 1``; the
    per-channel contributions are summed over dim 1, which is kept as a
    singleton dimension in the result.

    Args:
        input_tensor: torch tensor indexed as ``[batch, channel, ...]``.

    Returns:
        Tensor created with ``torch.zeros`` defaults (so it is not placed on
        the input's device) with dim 1 reduced to size 1.
    """
    label_planes = torch.zeros(input_tensor.shape)
    for channel in range(input_tensor.shape[1]):
        # Fill one channel for the whole batch at once.
        label_planes[:, channel] = torch.where(input_tensor[:, channel] == 1, channel + 1, 0)
    return label_planes.sum(dim=1, keepdim=True)


def trainer(args, config, model, savepath):
    """Train ``model``, validate it after every epoch and write logs and checkpoints.

    Args:
        args: object exposing ``seed``, ``batch_size``, ``max_epochs``,
            ``save_path``, ``base_lr``, ``num_classes`` and ``img_size``.
            The text log and all checkpoints are written below ``args.save_path``.
        config: mapping with the keys ``'loss_function'``, ``'optimizer'``,
            ``'ds_train'``, ``'ds_val'`` and ``'save_interval'``.
        model: ``torch.nn.Module``; moved to ``cuda:0`` when CUDA is available,
            otherwise to the CPU.
        savepath: directory below which the TensorBoard run directory ``log``
            is created.

    Returns:
        The string ``"Training Finished!"``.
    """
    # Initializations
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model.to(device)

    # Parameters
    loss_function = config['loss_function']
    optimizer = config['optimizer']
    dataset_train = config['ds_train']
    dataset_val = config['ds_val']
    save_interval = config['save_interval']

    def worker_init_fn(worker_id):
        # Give every DataLoader worker its own reproducible `random` seed.
        random.seed(args.seed + worker_id)

    # Data Loaders
    train_loader = DataLoader(dataset_train, batch_size=args.batch_size, shuffle=True, num_workers=16, pin_memory=True,
                              worker_init_fn=worker_init_fn)
    val_loader = DataLoader(dataset_val, batch_size=1, shuffle=True, num_workers=2, pin_memory=True,
                            worker_init_fn=worker_init_fn)

    max_iterations = args.max_epochs * len(train_loader)

    # logging
    # The log file and the checkpoints live in args.save_path; make sure it exists
    # so neither logging.basicConfig nor torch.save fails on a fresh run.
    os.makedirs(args.save_path, exist_ok=True)
    logging.basicConfig(filename=args.save_path + "/log.txt", level=logging.INFO,
                        format='[%(asctime)s.%(msecs)03d] %(message)s', datefmt='%H:%M:%S')
    logging.info(args)
    logging.info("{} iterations per epoch. {} max iterations ".format(len(train_loader), max_iterations))
    writer = SummaryWriter(savepath + '/log')

    best_metric = -1
    best_metric_epoch = -1
    iter_num = 0
    # Defined up front so the progress print works even before the first LR update.
    lrlog = args.base_lr

    try:
        for epoch in range(args.max_epochs):

            ############################
            #         Training         #
            ############################

            epoch_loss = 0.0
            model.train()
            for i_batch, sampled_batch in enumerate(train_loader):
                # get inputs and targets
                inputs, targets = sampled_batch['image'], sampled_batch['label']
                # here the input and target have the shape [batch, H, L, D]
                # so we need to add the channel dimension
                inputs, targets = inputs.unsqueeze(1), targets.unsqueeze(1)
                inputs, targets = inputs.to(device), targets.to(device)

                optimizer.zero_grad()
                outputs = model(inputs)
                loss = loss_function(outputs, targets)

                loss.backward()
                optimizer.step()

                # Accumulate a plain Python float: summing the loss tensors themselves
                # would keep autograd history alive for the whole epoch.
                loss_value = loss.item()
                epoch_loss += loss_value

                # update learning rate (polynomial decay with power 0.9)
                lr_ = args.base_lr * (1.0 - iter_num / max_iterations) ** 0.9
                for param_group in optimizer.param_groups:
                    param_group['lr'] = lr_
                lrlog = lr_
                writer.add_scalar('info/lr', lr_, iter_num)
                iter_num = iter_num + 1

                # write to log
                writer.add_scalar('info/total_loss', loss_value, iter_num)

            epoch_loss = epoch_loss / len(train_loader)

            logging.info('epoch %d : mean loss : %f' % (epoch, epoch_loss))

            ############################
            #        Validation        #
            ############################

            model.eval()
            with torch.no_grad():
                dice_tmp = []
                for i_batch, sampled_batch in enumerate(val_loader):
                    # get inputs and targets
                    inputs, targets = sampled_batch['image'], sampled_batch['label']
                    # targets needs to be transformed to one-hot encoding
                    # NOTE(review): unlike the training loop, `inputs` gets no channel
                    # axis here - confirm the validation transform already provides it.
                    targets = targets.unsqueeze(1)
                    targets = one_hot_encoder(targets, args.num_classes)
                    inputs, targets = inputs.to(device), targets.to(device)

                    val_outputs = model(inputs)
                    val_outputs = torch.softmax(val_outputs, dim=1)

                    # compute metric for current iteration (on class probabilities)
                    dice_tmp.append(dice_score(val_outputs.cpu().data.numpy(), targets.cpu().data.numpy()))

                # aggregate the final mean dice result
                metric = fmean(dice_tmp)

                # write to log
                writer.add_scalar('info/validation_metric', metric, epoch)
                logging.info('iteration %d : dice score : %f' % (epoch, metric))

                if metric > best_metric:
                    best_metric = metric
                    best_metric_epoch = epoch + 1
                    torch.save(model.state_dict(), os.path.join(args.save_path, "best_metric_model.pth"))
                    print("saved new best metric model")
                print(
                    f"current epoch: {epoch + 1} current mean dice: {metric:.4f}. Current learning rate {lrlog}"
                    f"\nbest mean dice: {best_metric:.4f} "
                    f"at epoch: {best_metric_epoch}"
                )

            ############################
            #          Saving          #
            ############################

            # add an example (the last validation batch) to tensorboard logging
            labs = to_one_arr_encoding(targets)
            # val_outputs already holds softmax probabilities, so argmax is enough.
            outputs = torch.argmax(val_outputs, dim=1, keepdim=True)

            if len(inputs.shape) == 5:
                # Log the central slice along the last axis; image, labels and
                # prediction must all be cut along that same axis.
                mid_slice = round(args.img_size / 2)
                image = inputs[:, :, :, :, mid_slice]
                labs = labs[:, :, :, :, mid_slice]
                outputs = outputs[:, :, :, :, mid_slice]
            else:
                image = inputs

            image = image[0]
            # Label indices are scaled by 50 before being logged as images.
            labs = torch.squeeze(labs * 50, 1)
            outputs = outputs[0] * 50
            image = (image - image.min()) / (image.max() - image.min())
            writer.add_image('train/Image', image, iter_num)
            writer.add_image('train/Prediction', outputs, iter_num)
            writer.add_image('train/GroundTruth', labs, iter_num)

            # Checkpoint every `save_interval` epochs and always after the last epoch
            # (written once, even when both conditions hold).
            is_last_epoch = epoch >= args.max_epochs - 1
            if (epoch + 1) % save_interval == 0 or is_last_epoch:
                save_mode_path = os.path.join(args.save_path, 'epoch_' + str(epoch) + '.pth')
                torch.save(model.state_dict(), save_mode_path)
                logging.info("save model to {}".format(save_mode_path))
    finally:
        # Flush and close the TensorBoard writer even if training is interrupted.
        writer.close()
    return "Training Finished!"

In [31]:
import os
# Show the current working directory of the kernel.
print(os.getcwd())
import sys
import os

# Add the project root folder to the module search path
sys.path.append('/content/TRUNet')



/content


In [34]:
import argparse
import ml_collections
import os
from glob import glob
import torch
import numpy as np
from monai.losses import DiceCELoss
from monai.metrics import DiceMetric
from TRUNet.TRUNet_network.model.ViT import VisionTransformer3d as TransUNet3d
from datetime import datetime
from TRUNet_network.trunet_train import trainer
from torchvision import transforms
from TRUNet_network.augmentations import RandomGenerator3d_zoom, Reshape3d_zoom

now = datetime.now()

def TransUNet_configs(img_size):
    """Build the configuration object for the hybrid TransUNet 3D network.

    Args:
        img_size: edge length of the input volume; it is divided by the patch
            size (16) to derive the number of patches per axis.

    Returns:
        An ``ml_collections.ConfigDict`` with the ResNet, transformer, decoder
        and patch settings; ``patches.grid`` holds the patch count per axis
        repeated for three axes.
    """
    cfg = ml_collections.ConfigDict()

    # ResNet part
    cfg.resnet = ml_collections.ConfigDict()
    cfg.resnet.num_layers = (3, 4, 9)
    cfg.resnet.width_factor = 1

    # Transformer part
    cfg.transformer_mlp_dim = 3072
    cfg.transformer_num_heads = 12
    cfg.transformer_num_layers = 12
    cfg.transformer_attention_dropout_rate = 0.0
    cfg.transformer_dropout_rate = 0.1

    # Segmentation head / decoder
    cfg.classifier = 'seg'
    cfg.decoder_channels = (256, 128, 64, 16)
    cfg.n_classes = 7
    cfg.n_skip = 3
    cfg.skip_channels = [512, 256, 64, 16]

    # Patch settings; the grid starts as None and is filled in below.
    cfg.patches = ml_collections.ConfigDict()
    cfg.patches.grid = None

    cfg.hidden_size = 768
    cfg.patches.size = 16

    # (results in 14 by 14 grid of patches for input size 224)
    cfg.patch_size = cfg.patches.size

    patches_per_axis = int(img_size / cfg.patches.size)
    cfg.patches.grid = (patches_per_axis, patches_per_axis, patches_per_axis)
    cfg.hybrid = True

    return cfg


class fetch_dataset:
    """Map-style dataset over the ``*.npz`` files of one directory, with optional cropping.

    Every archive is read for the arrays ``arr_0`` (image) and ``arr_1`` (label).
    The patient id of a file is parsed from its name: the part before the first
    ``_``, minus its first two characters, must be an integer (e.g. ``pt12_x.npz``).

    Args:
        base_dir: directory searched (non-recursively) for ``*.npz`` files.
        crop: ``None`` or the string ``'None'`` to disable cropping; otherwise the
            path of a whitespace-separated text file with one row per patient:
            ``id xmin xmax ymin ymax zmin zmax``. Blank lines are ignored.
        transform: optional callable applied to the sample dict before it is returned.
    """

    # Number of voxels added on every side of the crop bounding box.
    _CROP_MARGIN = 20

    def __init__(self, base_dir, crop=None, transform=None):
        self.transform = transform
        self.data_dir = base_dir
        sample_list = sorted(glob(os.path.join(base_dir, '*.npz')))
        self.sample_list = sample_list
        # os.path.basename handles whatever separator the platform uses.
        self.name = [os.path.basename(path).split('.npz')[0] for path in sample_list]
        self.pt = [int(name.split('_')[0][2:]) for name in self.name]
        self.crop = crop
        if self._crop_enabled():
            with open(crop) as file:
                # Tokenise every non-empty line once; split() tolerates tabs and
                # repeated spaces.
                rows = [line.split() for line in file if line.strip()]
            self.crop_pt = [int(row[0]) for row in rows]
            self.crop_xmin = [float(row[1]) for row in rows]
            self.crop_xmax = [float(row[2]) for row in rows]
            self.crop_ymin = [float(row[3]) for row in rows]
            self.crop_ymax = [float(row[4]) for row in rows]
            self.crop_zmin = [float(row[5]) for row in rows]
            self.crop_zmax = [float(row[6]) for row in rows]

    def _crop_enabled(self):
        """Return True when a crop file was supplied (neither ``None`` nor ``'None'``)."""
        return not (self.crop is None or self.crop == 'None')

    def __len__(self):
        """Number of ``.npz`` files found in the directory."""
        return len(self.sample_list)

    def __getitem__(self, idx):
        """Load sample ``idx``, crop it if enabled and apply the transform.

        Returns:
            A dict with the keys ``'image'``, ``'label'`` and ``'case_name'``, or
            whatever ``transform`` returns for that dict.

        Raises:
            ValueError: if cropping is enabled and the patient id of the file is
                not listed in the crop file.
        """
        data_path = self.sample_list[idx]
        # Close the archive right after both arrays are read, so file handles do
        # not pile up while iterating over the dataset.
        with np.load(data_path) as data:
            image, label = data['arr_0'], data['arr_1']

        if self._crop_enabled():
            row = self.crop_pt.index(self.pt[idx])
            x, y, z = image.shape
            margin = self._CROP_MARGIN
            # Extend the bounding box by the margin, clamped to the volume extent.
            xmin = max(0, int(self.crop_xmin[row]) - margin)
            ymin = max(0, int(self.crop_ymin[row]) - margin)
            zmin = max(0, int(self.crop_zmin[row]) - margin)
            xmax = min(x, int(self.crop_xmax[row]) + margin)
            ymax = min(y, int(self.crop_ymax[row]) + margin)
            zmax = min(z, int(self.crop_zmax[row]) + margin)
            image = image[xmin:xmax, ymin:ymax, zmin:zmax]
            label = label[xmin:xmax, ymin:ymax, zmin:zmax]

        sample = {'image': image, 'label': label, 'case_name': self.name[idx]}
        if self.transform:
            sample = self.transform(sample)

        return sample


if __name__ == "__main__":

    # Run options. No visible cell defines `args_`, so it is built here with argparse.
    # parse_known_args() is used so that the extra argv entries injected by a
    # Jupyter kernel are ignored instead of aborting the cell.
    # NOTE(review): the default values below are placeholders - confirm them
    # against the project's real settings before training.
    parser = argparse.ArgumentParser(description='Train TransUNet 3D', allow_abbrev=False)
    parser.add_argument('--root_path', type=str, default='./data',
                        help="directory containing the 'train' and 'val' sub-folders of .npz files")
    parser.add_argument('--save_path', type=str, default='./output',
                        help='directory for logs and checkpoints')
    parser.add_argument('--crop', type=str, default='None',
                        help="path to the crop-bounds text file, or 'None' to disable cropping")
    parser.add_argument('--checkpoint', type=str, default='None',
                        help="path to a state-dict checkpoint to resume from, or 'None'")
    parser.add_argument('--max_epochs', type=int, default=100)
    parser.add_argument('--batch_size', type=int, default=1)
    parser.add_argument('--num_classes', type=int, default=7)
    parser.add_argument('--img_size', type=int, default=128,
                        help='edge length of the cubic network input; should be a multiple of the patch size 16')
    args_, _ = parser.parse_known_args()

    args = ml_collections.ConfigDict()
    args.max_epochs = args_.max_epochs
    args.save_path = args_.save_path
    args.root_path = args_.root_path
    args.crop = args_.crop
    args.num_classes = args_.num_classes
    args.batch_size = args_.batch_size
    # img_size is read below and by the trainer, so it must be set on `args` as well.
    args.img_size = args_.img_size
    args.seed = 42
    args.base_lr = 0.01

    # The trainer writes its log file and checkpoints here.
    os.makedirs(args.save_path, exist_ok=True)

    # Transforms & Augmentations
    train_transforms = transforms.Compose(
        [RandomGenerator3d_zoom(output_size=(args.img_size, args.img_size, args.img_size))])
    val_transforms = transforms.Compose([Reshape3d_zoom(output_size=[args.img_size, args.img_size, args.img_size])])

    # Define model
    config_net = TransUNet_configs(args.img_size)
    model = TransUNet3d(config_net, img_size=args.img_size, num_classes=args.num_classes, zero_head=False,
                        vis=False)
    config = {'ds_val': fetch_dataset(base_dir=os.path.join(args.root_path, 'val'), transform=val_transforms,
                                      crop=args_.crop),
              'ds_train': fetch_dataset(base_dir=os.path.join(args.root_path, 'train'),
                                        transform=train_transforms, crop=args_.crop),
              'loss_function': DiceCELoss(include_background=False, to_onehot_y=True, softmax=True),
              'metric': DiceMetric(include_background=False, reduction="mean"),
              'optimizer': torch.optim.Adam(model.parameters(), args.base_lr),
              'save_interval': 50}

    if args_.checkpoint != 'None':
        print('loading checkpoint ', args_.checkpoint)
        # NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
        model_state = torch.load(args_.checkpoint, map_location='cpu')
        model.load_state_dict(model_state)

    trainer(args, config, model, args.save_path)

'\nif __name__ == "__main__":\n\n    args = ml_collections.ConfigDict()\n    args.max_epochs = args_.max_epochs\n    args.save_path = args_.save_path\n    args.root_path = args_.root_path\n    args.crop = args_.crop\n    args.num_classes = args_.num_classes\n    args.batch_size = args_.batch_size\n    args.seed = 42\n    args.base_lr = 0.01\n\n    # Transforms & Augmentations\n    train_transforms = transforms.Compose(\n        [RandomGenerator3d_zoom(output_size=(args.img_size, args.img_size, args.img_size))])\n    val_transforms = transforms.Compose([Reshape3d_zoom(output_size=[args.img_size, args.img_size, args.img_size])])\n\n    # Define model\n    config_net = TransUNet_configs(args.img_size)\n    model = TransUNet3d(config_net, img_size=args.img_size, num_classes=args.num_classes, zero_head=False,\n                        vis=False)\n    config = {\'ds_val\': fetch_dataset(base_dir=os.path.join(args.root_path, \'val\'), transform=val_transforms,\n                                