In [1]:
import torch
import numpy as np
from matplotlib import pyplot as plt
from torch.utils.data import DataLoader
from pulse_method.toolbox_pulse import *
from models.networks.Attention_unet import AttentionUnet
from transforms.augmentations import *
from models.layers.grid_attention import *

import os
import json
import random

In [2]:
from pulse_method.toolbox_pulse import thermograms

# Thermogram operator passed to SequenceDataset further down (built with default arguments).
# NOTE(review): the name shadows the stdlib `operator` module inside this notebook.
operator = thermograms()

In [3]:
# Training-time augmentation steps, listed in the order they are handed to Compose3D.
# The classes presumably come from the star-import of transforms.augmentations;
# all are constructed with default arguments except RandomFlip3D's `axes`.
transforms = [
    RandomPhaseAwareSpeedChange(),
    RandomBrightnessContrast(),
    PrependFirstFrame(),
    RandomFlip3D(axes=(1, 2)),  # NOTE(review): axes 1 and 2 assumed to be the spatial dims — confirm
    RandomElasticTransform(),
    RandomSequenceRotation(),
    RandomCropSequence(),
    AddGaussianNoise3D(),
    NormalizeTo01(),
]

In [4]:
# Wrap the augmentation list in a single Compose3D object
# (presumably applies the steps in list order — defined in the project, not visible here).
modification = Compose3D(transforms=transforms)

In [5]:
def create_splits(data_dir, output_dir, seed=42, val_fraction=0.1,
                  subset_fractions=(0.5, 0.75, 1.0)):
    """Shuffle the files in ``data_dir`` into one validation set and nested training subsets.

    A fraction ``val_fraction`` of all files is held out as a validation set
    shared by every split. From the remaining files, prefixes covering each
    fraction in ``subset_fractions`` become the training sets, so smaller
    training sets are always subsets of larger ones. Each split is written to
    ``<output_dir>/split_<percent>.json`` as ``{"train": [...], "val": [...]}``
    and a one-line summary is printed.

    Args:
        data_dir: Directory whose regular files (sub-directories are skipped)
            are split. Stored paths are ``os.path.join(data_dir, name)``.
        output_dir: Directory for the JSON split files; created if missing.
        seed: Seed for the shuffle. A private ``random.Random`` instance is
            used, so the notebook's global ``random`` state is left untouched.
        val_fraction: Fraction of all files held out for validation
            (the count is truncated toward zero).
        subset_fractions: Fractions of the non-validation files used for each
            training split; the key/file name is the rounded percentage. The
            default reproduces ``split_50``, ``split_75`` and ``split_100``.

    Raises:
        ValueError: If ``data_dir`` contains no files, or a fraction is
            outside its valid range.
    """
    if not 0.0 <= val_fraction < 1.0:
        raise ValueError(f"val_fraction must be in [0, 1), got {val_fraction}")
    for frac in subset_fractions:
        if not 0.0 < frac <= 1.0:
            raise ValueError(f"subset fractions must be in (0, 1], got {frac}")

    # Private RNG: yields the same permutation as random.seed(seed) followed by
    # random.shuffle, without reseeding the global generator other cells use.
    rng = random.Random(seed)

    # 1. Collect all regular-file paths. Sorting before shuffling is required for
    #    reproducibility, because os.listdir returns entries in arbitrary order.
    all_files = sorted(
        os.path.join(data_dir, name)
        for name in os.listdir(data_dir)
        if os.path.isfile(os.path.join(data_dir, name))
    )
    if not all_files:
        raise ValueError(f"No files found in {data_dir!r}")
    rng.shuffle(all_files)

    # 2. Hold out the validation files.
    n_val = int(val_fraction * len(all_files))
    val_files = all_files[:n_val]
    remaining_files = all_files[n_val:]

    # 3. Nested training subsets: prefixes of the same shuffled list.
    n_rem = len(remaining_files)
    subsets = {
        str(round(frac * 100)): remaining_files[:int(frac * n_rem)]
        for frac in subset_fractions
    }

    # 4. Save JSON files.
    os.makedirs(output_dir, exist_ok=True)

    for key, train_files in subsets.items():
        split = {
            "train": train_files,
            "val": val_files
        }
        out_path = os.path.join(output_dir, f"split_{key}.json")
        with open(out_path, "w") as f:
            json.dump(split, f, indent=4)
        print(f"Saved {out_path} with {len(train_files)} train and {len(val_files)} val samples.")


In [6]:
# Build the train/val split files from every file in All_data (create_splits' default seed).
ALL_DATA_DIR = "/mnt/43e5e0ce-4877-4cb7-9293-17b386c78736/attention_unet/data/All_data"
SPLITS_DIR = "/mnt/43e5e0ce-4877-4cb7-9293-17b386c78736/attention_unet/splits"
create_splits(ALL_DATA_DIR, SPLITS_DIR)

Saved /mnt/43e5e0ce-4877-4cb7-9293-17b386c78736/attention_unet/splits/split_50.json with 162 train and 36 val samples.
Saved /mnt/43e5e0ce-4877-4cb7-9293-17b386c78736/attention_unet/splits/split_75.json with 243 train and 36 val samples.
Saved /mnt/43e5e0ce-4877-4cb7-9293-17b386c78736/attention_unet/splits/split_100.json with 324 train and 36 val samples.


In [7]:
from datasets.cd_dataset import SequenceDataset

In [8]:
import json

In [9]:
# Load one train/val split written by `create_splits`.
# Read from the same absolute directory passed to `create_splits` as its output
# directory, so this cell no longer depends on the kernel's working directory.
SPLITS_DIR = "/mnt/43e5e0ce-4877-4cb7-9293-17b386c78736/attention_unet/splits"
SPLIT_KEY = "50"  # "50", "75" or "100": % of the non-validation files used for training

with open(os.path.join(SPLITS_DIR, f"split_{SPLIT_KEY}.json"), "r") as f:
    splits = json.load(f)

train_data = splits['train']
val_data = splits['val']

# Guard against data leakage: no file may appear in both train and val.
assert not set(train_data) & set(val_data), "train/val splits overlap"

In [None]:
# Training dataset built from the split's file list, the augmentation pipeline
# and the thermogram operator (argument semantics are defined by SequenceDataset, not shown here).
train_dataset = SequenceDataset(train_data, modification, operator)

In [11]:
# A seeded generator makes the shuffle order reproducible across kernel restarts.
# DataLoader also derives each worker's base seed from it, so augmentations
# (presumably drawing from torch/random/numpy RNGs inside the workers) become repeatable.
LOADER_SEED = 42  # same value as create_splits' default seed

train_loader = DataLoader(
    train_dataset,
    batch_size=1,
    shuffle=True,
    num_workers=4,
    generator=torch.Generator().manual_seed(LOADER_SEED),
)

In [12]:
# Pull a single batch from the loader to sanity-check tensor shapes.
batch_1 = next(iter(train_loader))

In [13]:
# Shape of the first batch element (presumably the input sequence).
batch_1[0].shape

torch.Size([1, 1, 256, 256, 40])

In [14]:
# Shape of the second batch element (presumably the target).
batch_1[1].shape

torch.Size([1, 1, 256, 256, 1])