In [6]:
import argparse
from multissl.data import get_transform, tifffile_loader, MixedUAVDataset
import tifffile as tiff
from multissl.models import build_model
import pytorch_lightning as pl

import torch

from lightly.data import LightlyDataset
from lightly.transforms.multi_view_transform import MultiViewTransform
import pytorch_lightning as pl
import torch
import numpy as np
from PIL import Image
from typing import Dict, Optional, Tuple, List, Any

In [7]:

import torch
from torch.utils.data.dataloader import default_collate

def restack_nested(list_of_tensor_lists):
    #     Determine the number of tensor positions from the first list
    num_positions = len(list_of_tensor_lists[0])
    
    # Initialize empty lists to collect tensors at each position
    collected_tensors = [[] for _ in range(num_positions)]
    
    # Collect tensors by their position in each inner list
    for tensor_list in list_of_tensor_lists:
        for i, tensor in enumerate(tensor_list):
            collected_tensors[i].append(tensor)
    
    # Stack each collection of tensors
    return [torch.stack(tensor_collection) for tensor_collection in collected_tensors]
    


def multisensor_views_collate_fn(batch: List[Dict[str, torch.Tensor]]) -> Dict[str, Any]:
    """Collect samples by type for separate processing"""
    result = {
        'rgb_only': {'data': [], 'indices': []},
        'ms_only': {'data': [], 'indices': []},
        'aligned': {'rgb': [], 'ms': [], 'indices': []}
    }
    
    for i, item in enumerate(batch):
        if item['type'] == 'rgb_only':
            result['rgb_only']['data'].append(item['rgb'])
            result['rgb_only']['indices'].append(i)
        elif item['type'] == 'ms_only':
            result['ms_only']['data'].append(item['ms'])
            result['ms_only']['indices'].append(i)
        elif item['type'] == 'aligned':
            result['aligned']['rgb'].append(item['rgb'])
            result['aligned']['ms'].append(item['ms'])
            result['aligned']['indices'].append(i)
    
    # Convert lists to tensors where applicable
    if result['rgb_only']['data']:
        result['rgb_only']['data'] = restack_nested(result['rgb_only']['data'])
    if result['ms_only']['data']:
        result['ms_only']['data'] = restack_nested(result['ms_only']['data'])
    if result['aligned']['rgb']:
        result['aligned']['rgb'] = restack_nested(result['aligned']['rgb'])
        result['aligned']['ms'] = restack_nested(result['aligned']['ms'])
    
    # Add original batch size for reference
    result['batch_size'] = len(batch)
    
    return result

    
# Create a multiview transform that returns three different augmentations of each image.
transform_multispectral = get_transform(img_size=224, std_noise  =0.1, 
                                        brightness_factor=0.1 ,
                                        max_shift=0.1)
tfs = [transform_multispectral for i in range(4)]

transform_ms = MultiViewTransform(transforms=tfs)

pl.seed_everything(42)


dataset_train_msrgb = MixedUAVDataset(
    root_dir ="../../msdata/data/output_multi/" ,
    transform = transform_ms,
    reduce_by = 160
)

length_dataset = len(dataset_train_msrgb)
print("Loaded dataset, dataset size: "+ str(length_dataset))


dataloader_train_msrgb = torch.utils.data.DataLoader(
    dataset_train_msrgb,                            # Pass the dataset to the dataloader.
    batch_size=16,         # A large batch size helps with learning.
    shuffle=True,                       # Shuffling is important!
    drop_last = True,
    num_workers=0,
    collate_fn = multisensor_views_collate_fn
)




Seed set to 42


Loaded dataset, dataset size: 480


In [8]:
batch = next(iter(dataloader_train_msrgb))

In [13]:
len(batch["rgb_only"]["data"])

4

In [96]:
# Create a dataset from your image folder.
dataset_train_ms = LightlyDataset(
    input_dir = "../../msdata/data/output_multi/RGBMS",
    transform = transform_ms,
)

def jpg_loader(f):
    return np.array(Image.open(f))
def tifffile_loader(f):
    img_path = f
    with tiff.TiffFile(img_path) as tif:
        image_array = tif.asarray()
        print(f"Array shape: {image_array.shape}")
    return image_array
dataset_train_ms.dataset.loader = tifffile_loader

length_dataset = len(dataset_train_ms)
print("Loaded dataset, dataset size: "+ str(length_dataset))
    
# Build a PyTorch dataloader.
dataloader_train_ms = torch.utils.data.DataLoader(
    dataset_train_ms,                            # Pass the dataset to the dataloader.
    batch_size=16,         # A large batch size helps with learning.
    shuffle=True,                       # Shuffling is important!
    drop_last = True,
    num_workers=0

)

Loaded dataset, dataset size: 116608


In [97]:
batch2 = next(iter(dataloader_train_ms))

Array shape: (512, 512, 4)
Array shape: (512, 512, 4)
Array shape: (512, 512, 4)
Array shape: (512, 512, 3)
Array shape: (512, 512, 4)
Array shape: (512, 512, 4)
Array shape: (512, 512, 4)
Array shape: (512, 512, 4)
Array shape: (512, 512, 4)
Array shape: (512, 512, 3)
Array shape: (512, 512, 3)
Array shape: (512, 512, 4)
Array shape: (512, 512, 4)
Array shape: (512, 512, 3)
Array shape: (512, 512, 4)
Array shape: (512, 512, 4)


RuntimeError: stack expects each tensor to be equal size, but got [4, 224, 224] at entry 0 and [3, 224, 224] at entry 3

In [81]:
len(batch2[0][0])

16