In [5]:
import os
# Set before torch / MONAI are imported below.
# NOTE(review): presumably a workaround for duplicate OpenMP runtime errors on this machine — confirm it is still needed.
os.environ.update(KMP_DUPLICATE_LIB_OK='True')
from pathlib import Path
import torch
from monai.transforms import (
    Compose, NormalizeIntensityd, AsDiscreted, LoadImaged, EnsureChannelFirstd,ConcatItemsd,DeleteItemsd,
    Resized, Rand3DElasticd, RandFlipd, RandRotate90d, Rand3DElasticd,RandBiasFieldd,SpatialPadd,
    RandShiftIntensityd, EnsureTyped, Activationsd, EnsureTyped, KeepLargestConnectedComponentd
)
from monai.data import Dataset, DataLoader

import matplotlib.pyplot as plt

In [2]:
# Root of the ISLES 2022 data; expected to contain DWIs/, FLAIRs/ and Labels/ subfolders.
# Set the ISLES2022_DIR environment variable to point elsewhere instead of editing
# this machine-specific default.
dataset_path = os.environ.get(
    'ISLES2022_DIR',
    'E:\\Codigos\\Python\\Brain Stroke Segmentation\\ISLE2022',
)

In [3]:
# Each modality (and the lesion masks) lives in its own folder under dataset_path.
dwi_dir = Path(dataset_path)/'DWIs'
flair_dir = Path(dataset_path)/'FLAIRs'
label_dir = Path(dataset_path)/'Labels'

# Cases are paired purely by sorted filename order across the three folders.
# NOTE(review): assumes the filenames in all three folders sort into the same case order — confirm.
dwis = sorted(dwi_dir.glob('*.nii.gz'))
labels = sorted(label_dir.glob('*.nii.gz'))
flairs = sorted(flair_dir.glob('*.nii.gz'))

# zip() silently stops at the shortest list, so a missing file in one folder would
# drop cases or shift the pairing; an empty/missing folder would yield an empty
# datalist that only fails much later. Fail loudly here instead.
if not dwis:
    raise FileNotFoundError(f"No '*.nii.gz' files found in {dwi_dir}")
if not (len(dwis) == len(flairs) == len(labels)):
    raise ValueError(
        f"Modality count mismatch: {len(dwis)} DWIs, {len(flairs)} FLAIRs, {len(labels)} labels"
    )

# One dict per case with string paths, read by LoadImaged under the same keys.
datalist = [
    {'dwi': str(dwi), 'flair': str(flair), 'label': str(lbl)}
    for dwi, flair, lbl in zip(dwis, flairs, labels)
]

In [6]:
# Deterministic preprocessing shared by training and validation: load both MRI
# modalities and the mask, resample everything to one fixed grid, normalize the
# images, then stack DWI + FLAIR into a single "image" entry.
det_transforms = Compose(transforms=[
    LoadImaged(keys=("dwi", "flair", "label")),
    EnsureChannelFirstd(keys=("dwi", "flair", "label")),
    # Trilinear for intensities; nearest for the mask so label values stay discrete.
    Resized(
        keys=("dwi", "flair", "label"),
        spatial_size=(176, 176, 64),
        mode=("trilinear", "trilinear", "nearest"),
    ),
    # Per-channel statistics, computed over non-zero voxels only.
    NormalizeIntensityd(keys=("dwi", "flair"), channel_wise=True, nonzero=True),
    # Stack along the channel axis -> "image" holds DWI then FLAIR.
    ConcatItemsd(keys=("dwi", "flair"), name="image", dim=0),
    # The per-modality entries are redundant once concatenated.
    DeleteItemsd(keys=("dwi", "flair")),
])

In [10]:
# Training-only random augmentations, applied after det_transforms.
# Geometric transforms act on "image" and "label" together so the mask stays aligned;
# intensity transforms act on "image" only.
prob = 0.5
rand_transforms = Compose(transforms=[
    # Independent random flip along each of the three spatial axes.
    *[RandFlipd(keys=["image", "label"], prob=prob, spatial_axis=axis) for axis in (0, 1, 2)],
    RandRotate90d(keys=["image", "label"], prob=prob, max_k=3),
    RandBiasFieldd(keys=["image"], prob=prob, coeff_range=(0.1, 0.2)),
    RandShiftIntensityd(keys=["image"], prob=prob, offsets=0.1),
    # Currently disabled candidates, kept for reference:
    # RandSpatialCropd(keys=["image", "label"], roi_size=(176,176,64), random_size=False)
    # Rand3DElasticd(keys=["image", "label"], sigma_range=(5,7), magnitude_range=(50,100),
    #                padding_mode='zeros', prob=0.2, mode=("trilinear","nearest"))
])

In [11]:
# Training = deterministic preprocessing followed by random augmentation;
# validation uses only the deterministic part.
train_preprocessing = Compose(transforms=(det_transforms, rand_transforms))
val_preprocessing = Compose(transforms=(det_transforms,))

In [14]:
# Smoke-test the full training pipeline on one case and check the sample layout
# matches what the model expects (in_channels=2) before building datasets.
X = train_preprocessing(datalist[20])
assert X['image'].shape[0] == 2, "expected DWI + FLAIR stacked as 2 channels"
assert X['image'].shape[1:] == X['label'].shape[1:], "image and label spatial shapes differ"
X['image'].shape


torch.Size([2, 176, 176, 64])

In [15]:
# Positional train/validation split of the sorted case list.
# Python slices are end-exclusive: [:n_train] covers cases 0..n_train-1 and
# [n_train:] covers every remaining case, so no case is left unused.
n_train = 200
train_ds = Dataset(
    data=datalist[:n_train],
    transform=train_preprocessing,
)
val_ds = Dataset(
    data=datalist[n_train:],
    transform=val_preprocessing,
)
assert len(val_ds) > 0, f"need more than {n_train} cases to build a validation set"

In [16]:
# Checkpoint whose weights are adapted and merged into the model further down.
# Set PRETRAINED_MODEL_PATH to use a different file instead of editing this
# machine-specific default.
pretrained_path = os.environ.get(
    'PRETRAINED_MODEL_PATH',
    'E:\\Codigos\\Python\\Brain Stroke Segmentation\\pretrained_model.pt',
)

In [17]:
# SegResNet is the architecture instantiated below; UNet is imported but not used in this part of the notebook.
from monai.networks.nets import UNet, SegResNet

In [22]:
# 3D SegResNet taking the 2-channel "image" (DWI + FLAIR) and predicting a single
# output channel for the lesion mask.
# NOTE(review): init_filters / blocks_down / blocks_up presumably mirror the pretrained
# checkpoint's configuration so its weights can be merged in — confirm.
model = SegResNet(
    spatial_dims=3,
    init_filters=16,
    in_channels=2,
    out_channels=1,
    dropout_prob=0.2,
    blocks_down=(1, 2, 2, 4),
    blocks_up=(1, 1, 1),
)
model = model.to('cpu')

In [23]:
# Load the checkpoint (a flat name -> tensor state dict) onto the CPU.
# weights_only=True restricts unpickling to tensors and plain containers, so a
# tampered .pt file cannot execute arbitrary code (argument available since torch 1.13).
pretrained_dict = torch.load(pretrained_path, map_location='cpu', weights_only=True)
# The model's current parameters/buffers by name; checkpoint tensors are merged into this dict.
model_dict = model.state_dict()

In [21]:
# Checkpoint's first conv weight; dim 1 is the number of input channels it was trained with.
pretrained_dict['convInit.conv.weight'].size()

torch.Size([16, 4, 3, 3, 3])

In [24]:
# This model's first conv weight; dim 1 is its input channel count (DWI + FLAIR).
model_dict['convInit.conv.weight'].size()

torch.Size([16, 2, 3, 3, 3])

In [32]:
# Pull out the checkpoint tensors that do not fit this model as-is: the first conv
# (checkpoint has 4 input channels, this model has 2) and the final output conv
# weight and bias (this model has out_channels=1).
init_weights, out_weights, out_bias = (
    pretrained_dict[key]
    for key in ('convInit.conv.weight', 'conv_final.2.conv.weight', 'conv_final.2.conv.bias')
)

In [33]:
# Adapt the checkpoint's first conv to this model's input channel count.
# Copying the channel-averaged filter into every new channel would scale the
# first-layer response by (new channels / old channels), i.e. halve it for 4 -> 2.
# Summing over the checkpoint's input channels and dividing by the new count keeps
# the response to equally-scaled inputs the same as in the pretrained network.
n_in = model_dict['convInit.conv.weight'].shape[1]
adapted_init_weights = (init_weights.sum(dim=1, keepdim=True) / n_in).repeat(1, n_in, 1, 1, 1)
# Collapse the checkpoint's output channels into the single channel this model predicts.
adapted_out_weights = out_weights.mean(dim=0, keepdim=True)
adapted_out_bias = out_bias.mean(dim=0, keepdim=True)

In [34]:
# Put the adapted tensors back under their original keys so they survive the
# shape filter and get merged into the model's state dict.
# Alternative to averaging: keep only the checkpoint's first input channel,
# e.g. init_weights[:, :1].clone().
pretrained_dict.update({
    'convInit.conv.weight': adapted_init_weights,
    'conv_final.2.conv.weight': adapted_out_weights,
    'conv_final.2.conv.bias': adapted_out_bias,
})

In [35]:
# Keep only checkpoint tensors whose key exists in this model with an identical
# shape; anything else cannot be loaded into it. Report what is dropped, and which
# model tensors receive nothing from the checkpoint, so an unexpected mismatch
# (e.g. a renamed layer) is visible instead of silently ignored.
compatible_dict = {
    k: v for k, v in pretrained_dict.items()
    if k in model_dict and v.shape == model_dict[k].shape
}
skipped_keys = [k for k in pretrained_dict if k not in compatible_dict]
not_in_checkpoint = [k for k in model_dict if k not in compatible_dict]
pretrained_dict = compatible_dict
print(
    f"Using {len(pretrained_dict)} checkpoint tensors; "
    f"skipped {len(skipped_keys)} incompatible: {skipped_keys}; "
    f"{len(not_in_checkpoint)} model tensors keep their initial values: {not_in_checkpoint}"
)

In [36]:
# Overwrite the model's own entries with every compatible checkpoint tensor.
model_dict.update(pretrained_dict.items())

In [37]:
# model_dict started as model.state_dict() and only had existing keys overwritten
# with same-shape tensors, so it is a complete state dict. Load it strictly so any
# missing or unexpected key raises instead of being silently ignored.
model.load_state_dict(model_dict, strict=True)

<All keys matched successfully>