In [17]:
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import h5py

import cv2
import albumentations as A
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from torch.utils.data.dataset import Dataset
from torch.utils.data import DataLoader
import torch.optim as optim
import matplotlib.pyplot as plt
from albumentations.pytorch import ToTensorV2

import time 
from tqdm import tqdm
from glob import glob
import zipfile
import shutil


In [18]:
!git clone https://github.com/soniamartinot/MVA-Dose-Prediction.git

fatal: destination path 'MVA-Dose-Prediction' already exists and is not an empty directory.


In [19]:
# from google.colab import drive
# drive.mount('/content/drive')

In [20]:
transform_image = A.Compose([
    A.Normalize(
    mean=[97.6],
    std=[314.5],
    max_pixel_value=255),
    A.ElasticTransform(alpha=15, sigma=5, alpha_affine=5, p=0.5),
    A.Rotate(limit=10, p=0.5),
    A.GaussNoise(var_limit=(10, 50), p=0.5),
    A.RandomBrightnessContrast(brightness_limit=0.1, contrast_limit=0.1, p=0.5),
    A.Blur(blur_limit=(3, 7), p=0.5),
])

transform_mask = A.Compose([
    A.ElasticTransform(alpha=15, sigma=5, alpha_affine=5, p=0.5),
    A.Rotate(limit=10, p=0.5),
], additional_targets={'mask': 'mask'})

transform_val = A.Normalize(
                  mean=[97.6],
                  std=[314.5],
                  max_pixel_value=255)

train_transforms = {
    'image': transform_image,
    'mask': transform_mask
}

val_transforms = {
    'image': transform_val,
    'mask': None
}


In [21]:
class DoseDataset(torch.utils.data.Dataset):
    def __init__(self, data_path, transform=None):
        self.data_path = data_path
        self.samples = os.listdir(data_path)
        self.transform = transform

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        sample_path = self.data_path + os.sep + self.samples[idx]
        ct_scan = np.load(sample_path + os.sep + 'ct.npy')
        possible_dose_mask = np.load(sample_path + os.sep + 'possible_dose_mask.npy')
        dose = torch.from_numpy(np.load(sample_path + os.sep + 'dose.npy')).float()
        structure_masks = np.load(sample_path + os.sep + 'structure_masks.npy')

        combined_masks = np.concatenate([possible_dose_mask[np.newaxis, :, :], structure_masks], axis=0)

        if self.transform:
          if self.transform == 'train_transforms':
            augmentations = self.transform(image=ct_scan, mask=combined_masks)
            ct_scan = augmentations['image']
            combined_masks = augmentations['mask']
          if self.transform == 'val_transforms':
            augmentations = self.transform(image=ct_scan)
            ct_scan = augmentations['image']

        input_image = torch.from_numpy(np.concatenate([ct_scan[np.newaxis, :, :], combined_masks], axis=0)).float()
        dose = dose.unsqueeze(0)

        return {'input_image': input_image, 'dose': dose}

In [22]:
class test_DoseDataset(torch.utils.data.Dataset):
    def __init__(self, data_path, transform=None):
        self.data_path = data_path
        self.samples = os.listdir(data_path)
        self.transform = transform

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        sample_path = self.data_path + os.sep + self.samples[idx]
        ct_scan = torch.from_numpy(np.load(sample_path + os.sep + 'ct.npy')).float()
        possible_dose_mask = torch.from_numpy(np.load(sample_path + os.sep + 'possible_dose_mask.npy')).float()
        structure_masks = torch.from_numpy(np.load(sample_path + os.sep + 'structure_masks.npy')).float()

        input_image = np.concatenate([ct_scan[np.newaxis, :, :], possible_dose_mask[np.newaxis, :, :], structure_masks], axis=0)
        input_image = torch.tensor(input_image, dtype=torch.float32)
        
        return {'sample_name': self.samples[idx], 'input_image': input_image}

In [23]:
train_dir = "./MVA-Dose-Prediction/train/"

train_dataset = DoseDataset(train_dir, transform=train_transforms)
train_dataloader = DataLoader(train_dataset, batch_size=32, num_workers=2, shuffle=True, pin_memory=True)
print(len(train_dataloader))

244


In [24]:
val_dir = "./MVA-Dose-Prediction/validation/"

val_dataset = DoseDataset(val_dir, transform=val_transforms)
val_dataloader = DataLoader(val_dataset, batch_size=32, num_workers=2, shuffle=True, pin_memory=True)
print(len(val_dataloader))

38


In [25]:
test_dir = "./MVA-Dose-Prediction/test"

test_dataset = test_DoseDataset(test_dir)
test_dataloader = DataLoader(test_dataset, batch_size=32, num_workers=2, shuffle=True, pin_memory=True)
print(len(test_dataloader))

38


In [26]:
class DepthwiseSeparableConv(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, bias=True):
        super(DepthwiseSeparableConv, self).__init__()
        self.depthwise = nn.Conv2d(in_channels, in_channels, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=in_channels, bias=False)
        self.pointwise = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=bias)

    def forward(self, x):
        x = self.depthwise(x)
        x = self.pointwise(x)
        return x


In [27]:
class DilatedConv(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, bias=True):
        super(DilatedConv, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, bias=bias)

    def forward(self, x):
        x = self.conv(x)
        return x


In [28]:
class DepthwiseSeparableDilatedConv(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, bias=True):
        super(DepthwiseSeparableDilatedConv, self).__init__()
        self.depthwise = nn.Conv2d(in_channels, in_channels, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=in_channels, bias=False)
        self.pointwise = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=bias)

    def forward(self, x):
        x = self.depthwise(x)
        x = self.pointwise(x)
        return x


In [29]:
class BasicBlock(nn.Module):
    def __init__(self, in_channels: int, forward_expansion: int, out_channels: int):
        super(BasicBlock, self).__init__()
        """
        A convolutional block with batch normalization and ReLU activation.
        """
        self.conv1 = nn.Conv2d(in_channels, forward_expansion, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(forward_expansion)
        self.conv2 = nn.Conv2d(forward_expansion, out_channels, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        out = self.conv1(x)
        out = self.bn1(out)
        out = nn.ReLU()(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out = nn.ReLU()(out)
        return out

class BasicBlock_skip(nn.Module):
    def __init__(self, in_channels: int, forward_expansion: int, out_channels: int):
        super(BasicBlock_skip, self).__init__()

        self.conv1 = nn.Conv2d(in_channels, forward_expansion, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(forward_expansion)
        self.conv2 = nn.Conv2d(forward_expansion, out_channels, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)

        self.residual = nn.Identity()
        if in_channels != out_channels:
            self.residual = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0),
                nn.BatchNorm2d(out_channels)
            )

    def forward(self, x):
        identity = self.residual(x)

        out = self.conv1(x)
        out = self.bn1(out)
        out = nn.ReLU()(out)
        out = self.conv2(out)
        out = self.bn2(out)

        out += identity
        out = nn.ReLU()(out)

        return out

class BasicBlock_DSDC(nn.Module):
    def __init__(self, in_channels: int, forward_expansion: int, out_channels: int, dilation=1):
        super(BasicBlock_DSDC, self).__init__()

        self.conv1 = DepthwiseSeparableDilatedConv(in_channels, forward_expansion, kernel_size=3, stride=1, padding=dilation, dilation=dilation)
        self.bn1 = nn.BatchNorm2d(forward_expansion)
        self.conv2 = DepthwiseSeparableDilatedConv(forward_expansion, out_channels, kernel_size=3, stride=1, padding=dilation, dilation=dilation)
        self.bn2 = nn.BatchNorm2d(out_channels)

        self.residual = nn.Identity()
        if in_channels != out_channels:
            self.residual = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0),
                nn.BatchNorm2d(out_channels)
            )

    def forward(self, x):
        identity = self.residual(x)

        out = self.conv1(x)
        out = self.bn1(out)
        out = nn.ReLU()(out)
        out = self.conv2(out)
        out = self.bn2(out)

        out += identity
        out = nn.ReLU()(out)

        return out


class UNet(nn.Module):
    def __init__(self, in_channels=12, out_channels=1):
        super(UNet, self).__init__()

        # Downward pass
        self.down1 = nn.Sequential(
            BasicBlock_DSDC(in_channels, 16, 16),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )
        self.down2 = nn.Sequential(
            BasicBlock_DSDC(16, 32, 32),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )
        self.down3 = nn.Sequential(
            BasicBlock_DSDC(32, 64, 64),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )
        self.down4 = nn.Sequential(
            BasicBlock_DSDC(64, 128, 128),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )

        # Bottom block
        self.bottom = BasicBlock_DSDC(128, 256, 128)

        # Upward pass
        self.up4 = nn.Sequential(
            nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2),
            nn.BatchNorm2d(64),
            nn.ReLU()
        )
        self.up3 = nn.Sequential(
            nn.ConvTranspose2d(64 + 64, 32, kernel_size=2, stride=2),
            nn.BatchNorm2d(32),
            nn.ReLU()
        )
        self.up2 = nn.Sequential(
            nn.ConvTranspose2d(32 + 32, 16, kernel_size=2, stride=2),
            nn.BatchNorm2d(16),
            nn.ReLU()
        )
        self.up1 = nn.Sequential(
            nn.ConvTranspose2d(16 + 16, out_channels, kernel_size=2, stride=2),
            nn.BatchNorm2d(out_channels),
            nn.Sigmoid()
        )

    def forward(self, x):
        # Downward pass
        x1 = self.down1(x)
        x2 = self.down2(x1)
        x3 = self.down3(x2)
        x4 = self.down4(x3)

        # Bottom block
        bottom = self.bottom(x4)

        # Upward pass
        y4 = self.up4(bottom)
        y3 = self.up3(torch.cat((y4, x3), dim=1))
        y2 = self.up2(torch.cat((y3, x2), dim=1))
        y1 = self.up1(torch.cat((y2, x1), dim=1))

        return y1



In [30]:
class BigBasicBlock(nn.Module):
    def __init__(self, in_channels: int, forward_expansion: int, out_channels: int):
        super(BigBasicBlock, self).__init__()
 
        self.conv1 = nn.Conv2d(in_channels, forward_expansion, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(forward_expansion)
        self.conv2 = nn.Conv2d(forward_expansion, out_channels, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        out = self.conv1(x)
        out = self.bn1(out)
        out = nn.ReLU()(out)
        out = nn.Dropout(0.1)(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out = nn.ReLU()(out)
        out = nn.Dropout(0.1)(out)
        return out

class BigUNet(nn.Module):
    def __init__(self, in_channels=12, out_channels=1):
        super(BigUNet, self).__init__()

        # Downward pass

        self.first = nn.Sequential(
            nn.Conv2d(in_channels, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
        )

        self.down1 = nn.Sequential(
            BigBasicBlock(64, 64, 128),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.down2 = nn.Sequential(
            BigBasicBlock(128, 128, 256),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.down3 = nn.Sequential(
            BigBasicBlock(256, 256, 512),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )

        # Bottom block
        self.bottom = BigBasicBlock(512, 512, 512)

        # Upward pass
        self.up3 = nn.Sequential(
            nn.ConvTranspose2d(512+512, 256, kernel_size=2, stride=2),
            BasicBlock_DSDC(256, 256, 256),
        )
        self.up2 = nn.Sequential(
            nn.ConvTranspose2d(256 + 256, 128, kernel_size=2, stride=2),
            BasicBlock_DSDC(128, 128, 128),
        )
        self.up1 = nn.Sequential(
            nn.ConvTranspose2d(128 + 128, 64, kernel_size=2, stride=2),
            nn.Conv2d(64, 32, kernel_size=3, stride=1, padding=1),
        )

        self.last = nn.Sequential(
            nn.Conv2d(32 + in_channels, 16, kernel_size=3, stride=1, padding=1),
            nn.Conv2d(16, 8, kernel_size=3, stride=1, padding=1),
            nn.Conv2d(8, 4, kernel_size=3, stride=1, padding=1),
            nn.Conv2d(4, 2, kernel_size=3, stride=1, padding=1),
            nn.Conv2d(2, 1, kernel_size=3, stride=1, padding=1)  
        )

    def forward(self, x):
        # Downward pass
        x_initial = self.first(x)
        x1 = self.down1(x_initial)
        x2 = self.down2(x1)
        x3 = self.down3(x2)

        # Bottom block
        bottom = self.bottom(x3)

        # Upward pass
        y3 = self.up3(torch.cat((bottom, x3), dim=1))
        y2 = self.up2(torch.cat((y3, x2), dim=1))
        y1 = self.up1(torch.cat((y2, x1), dim=1))
        
        # Concatenate input layer
        y_concat = torch.cat((y1, x), dim=1)
        
        y = self.last(y_concat)
        y = nn.ReLU()(y)

        return y



In [31]:
# Create a training loop
def train(model, dataloader, criterion, optimizer, device):
    model.train()
    running_loss = 0.0
    for i, data in enumerate(dataloader, 0):
        inputs, labels = data['input_image'].to(device), data['dose'].to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    return running_loss / (i + 1)

# Create a validation loop
def validate(model, dataloader, criterion, device):
    model.eval()
    running_loss = 0.0
    with torch.no_grad():
        for i, data in enumerate(dataloader, 0):
            inputs, labels = data['input_image'].to(device), data['dose'].to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            running_loss += loss.item()
    return running_loss / (i + 1)


In [None]:
# Train the model and evaluate its performance on the validation set
num_epochs = 15
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
model = BigUNet().to(device)
criterion = nn.L1Loss()
optimizer = optim.Adam(model.parameters(), lr=0.00001)

train_loss_list = []
val_loss_list = []

for epoch in range(num_epochs):
    train_loss = train(model, train_dataloader, criterion, optimizer, device)
    val_loss = validate(model, val_dataloader, criterion, device)
    print(f'Epoch {epoch + 1}/{num_epochs}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}')
    train_loss_list.append(train_loss)
    val_loss_list.append(val_loss)

# Save the trained model
torch.save(model.state_dict(), 'trained_model.pth')

cuda
Epoch 1/15, Train Loss: 2.3017, Val Loss: 1.9964
Epoch 2/15, Train Loss: 1.7295, Val Loss: 1.3789
Epoch 3/15, Train Loss: 1.2363, Val Loss: 1.1254
Epoch 4/15, Train Loss: 1.1257, Val Loss: 1.1543
Epoch 5/15, Train Loss: 1.0723, Val Loss: 1.0881
Epoch 6/15, Train Loss: 1.0315, Val Loss: 0.9993
Epoch 7/15, Train Loss: 1.0032, Val Loss: 0.9625
Epoch 8/15, Train Loss: 0.9783, Val Loss: 0.9618
Epoch 9/15, Train Loss: 0.9542, Val Loss: 0.9322
Epoch 10/15, Train Loss: 0.9346, Val Loss: 0.9792


Epoch 1/15, Train Loss: 2.3127, Val Loss: 1.7891
Epoch 2/15, Train Loss: 1.3737, Val Loss: 1.1563
Epoch 3/15, Train Loss: 1.1337, Val Loss: 1.1161
Epoch 4/15, Train Loss: 1.0713, Val Loss: 1.0499
Epoch 5/15, Train Loss: 1.0433, Val Loss: 1.1620
Epoch 6/15, Train Loss: 1.0064, Val Loss: 0.9521
Epoch 7/15, Train Loss: 0.9911, Val Loss: 1.0015
Epoch 8/15, Train Loss: 0.9615, Val Loss: 1.1773
Epoch 9/15, Train Loss: 0.9416, Val Loss: 1.1358
Epoch 10/15, Train Loss: 0.9284, Val Loss: 0.9237
Epoch 11/15, Train Loss: 0.9069, Val Loss: 0.9929
Epoch 12/15, Train Loss: 0.8920, Val Loss: 0.8746
Epoch 13/15, Train Loss: 0.8807, Val Loss: 0.9296
Epoch 14/15, Train Loss: 0.8664, Val Loss: 0.9162
Epoch 15/15, Train Loss: 0.8530, Val Loss: 0.8629


In [None]:
plt.plot(range(num_epochs), train_loss_list, label='Training Loss')
plt.plot(range(num_epochs), val_loss_list, label='Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.show

In [None]:
plt.plot(range(num_epochs), train_loss_list, label='Training Loss')
plt.plot(range(num_epochs), val_loss_list, label='Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.show




BigUnet -> train: val: 0.9933


BigUnet with dropout -> train: val: 1.001



In [None]:
# load the trained model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = UNet().to(device)
# Load the trained model
model.load_state_dict(torch.load('/content/drive/MyDrive/trained_model.pth'))

In [None]:
model.eval()

# Create a directory to store the predicted dose files
output_dir = "predictions"
os.makedirs(output_dir, exist_ok=True)

# Loop through the test dataset and make predictions
with torch.no_grad():
    for idx, data in enumerate(test_dataloader):
        sample_names = data['sample_name']
        inputs = data['input_image'].to(device)
        outputs = model(inputs)
        
        batch_outputs = outputs.cpu().numpy()
        
        # Save each sample's prediction separately
        for i, sample_name in enumerate(sample_names):
            output_file = os.path.join(output_dir, f'{sample_name}.npy')
            np.save(output_file, batch_outputs[i])

# Create a zip file with all the predicted dose files
with zipfile.ZipFile("submission.zip", "w") as zf:
    for file in os.listdir(output_dir):
        file_path = os.path.join(output_dir, file)
        zf.write(file_path, arcname=file)



In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class DenseBlock(nn.Module):
    def __init__(self, in_channels, growth_rate, num_layers, dropout_rate, l2):
        super(DenseBlock, self).__init__()
        self.layers = nn.ModuleList([
            nn.Sequential(
                nn.BatchNorm2d(in_channels + growth_rate * i),
                nn.ReLU(inplace=True),
                nn.Conv2d(in_channels + growth_rate * i, growth_rate, 3, padding=1, bias=False),
                nn.Dropout(dropout_rate) if dropout_rate else nn.Identity()
            )
            for i in range(num_layers)
        ])
    
    def forward(self, x):
        for layer in self.layers:
            out = layer(x)
            x = torch.cat([x, out], dim=1)
        return x

class TransitionDown(nn.Module):
    def __init__(self, in_channels, out_channels, dropout_rate):
        super(TransitionDown, self).__init__()
        self.layers = nn.Sequential(
            nn.BatchNorm2d(in_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels, out_channels, 1, bias=False),
            nn.Dropout(dropout_rate) if dropout_rate else nn.Identity(),
            nn.AvgPool2d(2)
        )

    def forward(self, x):
        return self.layers(x)

class DenseUNet(nn.Module):
    def __init__(self, input_shape, dropout_rate=None, l2=0.00000001, activation='relu', lr=0.0001, print_summary=False):
        super(DenseUNet, self).__init__()

        self.encoder = nn.Sequential(
            nn.Conv2d(input_shape[2], 32, 3, padding=1, bias=False),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True)
        )

        growth_rates = [8, 16, 32, 64]
        self.dense_blocks = nn.ModuleList([DenseBlock(32 + sum(growth_rates[:i]), growth_rate, 4, dropout_rate, l2) for i, growth_rate in enumerate(growth_rates)])
        self.transition_downs = nn.ModuleList([TransitionDown(32 + sum(growth_rates[:i + 1]), 32 + sum(growth_rates[:i + 1]), dropout_rate) for i in range(len(growth_rates) - 1)])

        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(800, 512, 3, stride=2, padding=1, output_padding=1, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 256, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            DenseBlock(256, 64, 4, dropout_rate, l2),
            nn.ConvTranspose2d(512, 256, 3, stride=2, padding=1, output_padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 128, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            DenseBlock(128, 32, 4, dropout_rate, l2),
            nn.ConvTranspose2d(256, 128, 3, stride=2, padding=1, output_padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 64, 1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            DenseBlock(64, 16, 4, dropout_rate, l2),
            nn.ConvTranspose2d(128, 64, 3, stride=2, padding=1, output_padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 32, 1, bias=False),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            DenseBlock(32, 8, 4, dropout_rate, l2)
            )

        self.output = nn.Conv2d(80, 1, 1, bias=True)
            
        self.print_summary = print_summary

        def forward(self, x):
            x1 = self.encoder(x)
            x = x1

            for dense_block, transition_down in zip(self.dense_blocks[:-1], self.transition_downs):
                x = dense_block(x)
                x = transition_down(x)

            x = self.dense_blocks[-1](x)

            x = self.decoder(x)
            x = torch.cat([x, x1], dim=1)
            x = self.output(x)
            
            return x

