In [2]:
from dataset import BasicDataset
import matplotlib.pyplot as plt

# Build the plain and the augmented dataset from the same file.
# The plain one is constructed first, then the augmented one.
filepath = '/vol/bitbucket/dks20/renal_ssn/labelbox_download/train_data.h5'
dataset, augmented_dataset = (
    BasicDataset(filepath, augment=use_augmentation)
    for use_augmentation in (False, True)
)
# test_dataset = TestDataset(filepath)

# Plot images from both datasets using matplotlib
# for i in range(20):
#     img, mask = dataset.__getitem__(i)
#     img = img / 255.0
#     img = img.permute(1, 2, 0).cpu().numpy()
#     mask = mask.cpu().numpy()

#     aug_img, aug_mask = augmented_dataset.__getitem__(i)
#     aug_img = aug_img.permute(1, 2, 0).cpu().numpy()
#     aug_mask = aug_mask.cpu().numpy()

    # plt.figure(figsize=(10, 10))
    # plt.subplot(2, 2, 1)
    # plt.imshow(img)
    # plt.title("Original Image")
    # plt.axis('off')

    # plt.subplot(2, 2, 2)
    # plt.imshow(mask)
    # plt.title("Original Mask")
    # plt.axis('off')

    # plt.subplot(2, 2, 3)
    # plt.imshow(aug_img)
    # plt.title("Augmented Image")
    # plt.axis('off')

    # plt.subplot(2, 2, 4)
    # plt.imshow(aug_mask)
    # plt.title("Augmented Mask")
    # plt.axis('off')

    # plt.show()

In [3]:
import torch
from torch.utils.data import DataLoader, random_split
import numpy as np

def get_train_loaders(train_path="/vol/bitbucket/dks20/renal_ssn/labelbox_download/train_data.h5",
                      batch_size=1,
                      augment=False,
                      split_ratio=0.9,
                      oversample=False,
                      oversample_classes=[2, 3],
                      oversample_weight=4,
                      seed=42):
    """ Get train and validation loaders for kidney dataset, with optional oversampling of minority classes

    The dataset is split once with a generator seeded by ``seed``. When
    ``oversample`` is True, every training sample whose mask contains at least
    one label from ``oversample_classes`` is drawn with relative weight
    ``oversample_weight`` (all other samples have weight 1) through a
    ``WeightedRandomSampler`` with replacement; an epoch still yields
    ``len(train_set)`` samples. When ``oversample`` is False the training
    loader simply shuffles.

    Computing the weights reads every training sample once, so enabling
    oversampling costs one extra pass over the training split up front.

    Args:
        train_path (str): path to training data
        batch_size (int): batch size
        augment (bool): whether to apply data augmentation
        split_ratio (float): fraction of data to use for training
        oversample (bool): whether to oversample minority classes
        oversample_classes (list): list of classes to oversample (never mutated here)
        oversample_weight (int): weight to assign to minority classes
        seed (int): random seed for reproducibility (split and oversampling sampler)

    Returns:
        train_loader (DataLoader): training DataLoader
        val_loader (DataLoader): validation DataLoader

    Raises:
        ValueError: if ``split_ratio`` is not strictly between 0 and 1, if
            ``batch_size`` is not positive, or if ``oversample`` is requested
            with a non-positive ``oversample_weight``.
    """
    # Imported locally so the function does not depend on the cell-level import list.
    from torch.utils.data import WeightedRandomSampler

    if not (0 < split_ratio < 1):
        raise ValueError("split_ratio must be between 0 and 1.")
    if batch_size <= 0:
        raise ValueError("batch_size must be a positive integer.")
    if oversample and oversample_weight <= 0:
        raise ValueError("oversample_weight must be positive.")

    train_dataset = BasicDataset(train_path, augment=augment)
    print("Number of training samples: ", len(train_dataset))

    # Split data into training and validation sets using seed
    train_len = int(split_ratio * len(train_dataset))
    val_len = len(train_dataset) - train_len

    # NOTE(review): both splits wrap the same BasicDataset instance, so the
    # validation split is built with the same `augment` flag as the training
    # split - confirm that augmented validation data is intended.
    train_set, val_set = random_split(train_dataset, [train_len, val_len], generator=torch.Generator().manual_seed(seed))

    # Minority class oversampling
    sampler = None
    if oversample:
        minority_classes = tuple(oversample_classes)
        sample_weights = []
        for index in range(len(train_set)):
            # Items are (image, mask) pairs; only the mask decides the weight.
            # NOTE(review): with augment=True the mask read here has presumably
            # been augmented too - confirm augmentation cannot remove a class.
            _, mask = train_set[index]
            mask = torch.as_tensor(mask)
            has_minority = any(bool((mask == cls).any()) for cls in minority_classes)
            sample_weights.append(float(oversample_weight) if has_minority else 1.0)

        sampler = WeightedRandomSampler(
            torch.as_tensor(sample_weights, dtype=torch.double),
            num_samples=len(sample_weights),
            replacement=True,
            generator=torch.Generator().manual_seed(seed),
        )

    # DataLoader rejects shuffle=True together with a sampler, so shuffle only
    # when no sampler was built.
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=(sampler is None), num_workers=0, pin_memory=True, sampler=sampler)
    val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False, num_workers=0, pin_memory=True)

    return train_loader, val_loader

train_loader, val_loader = get_train_loaders()

# Tally how many training batches carry at least one of the minority labels.
minority_labels = [2, 3]
length = len(train_loader)
normal_counter = 0  # only referenced by the disabled inspection code below
counter = 0

for img, mask in train_loader:
    # `label in mask` is a membership test on the mask batch; add 1 on a hit.
    counter += int(any(label in mask for label in minority_labels))

    # Disabled inspection code, kept for reference:
    # if mask.sum() == 0:
    #     counter += 1
    #     if counter >= 1:
    #         break
    # else:
    #     if normal_counter >= 1:
    #         continue

    #     normal_counter += 1
    #     print("non empty mask")
    #     print(f"img counts: {np.unique(img, return_counts=True)}")
    #     print(f"img min: {img.min()}, img max: {img.max()}")

    #     # plot img
    #     img = img / 255.0
    #     img = img.squeeze(0).permute(1, 2, 0).cpu().numpy()
    #     plt.imshow(img)
    #     plt.show()

print(f"Number of oversampled masks: {counter}/{length}")

Number of training samples:  5748
Number of oversampled masks: 1648/5173


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np

# Learning-rate range test: train for a few batches while growing the learning
# rate geometrically from init_lr towards max_lr, then plot loss against rate.

# Reuse the `device` chosen by an earlier cell if there is one; otherwise pick
# one here so the cell also runs on a fresh kernel.
if "device" not in globals():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# The model must live on the same device as the batches moved there below.
model = UNet().to(device)  # Assuming UNet is already defined
# NOTE(review): BCEWithLogitsLoss expects float targets shaped like the model
# output; the masks elsewhere in this notebook hold class ids such as 2 and 3,
# so a multi-class loss may be what is needed here - confirm against UNet.
criterion = nn.BCEWithLogitsLoss()  # Or any appropriate loss function
train_loader = DataLoader(dataset, batch_size=32, shuffle=True)  # Uses `dataset` from the first cell

# Initial setup
init_lr = 1e-7
max_lr = 1e-2
num_batches = 100  # Total number of batches for which we test learning rate
lr_mult = (max_lr / init_lr) ** (1 / num_batches)
diverge_factor = 4  # Stop once the loss exceeds this multiple of the best loss seen
optimizer = optim.Adam(model.parameters(), lr=init_lr)

lrs = []
losses = []
best_loss = float('inf')

model.train()
for i, (images, labels) in enumerate(train_loader):
    if i >= num_batches:
        break

    # Forward pass
    images, labels = images.to(device), labels.to(device)
    outputs = model(images)
    loss = criterion(outputs, labels)

    # Backward and optimize
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    loss_value = loss.item()
    if not np.isfinite(loss_value):
        # A NaN/inf loss means training has blown up; nothing useful follows.
        break

    # Record the rate that was actually used for this step, exactly once per
    # batch, so `lrs` and `losses` always have the same length.
    lrs.append(optimizer.param_groups[0]['lr'])
    losses.append(loss_value)

    # Stop when the loss has clearly diverged from the best value so far.
    if loss_value > diverge_factor * best_loss:
        break
    best_loss = min(best_loss, loss_value)

    # Update learning rate for the next batch
    for param_group in optimizer.param_groups:
        param_group['lr'] *= lr_mult

    # Reset weights to initial state for fair comparison (optional)
    # model.apply(weight_reset) if you implement a weight reset function

# Plotting results
import matplotlib.pyplot as plt
plt.plot(lrs, losses)
plt.xscale('log')
plt.xlabel('Learning Rate')
plt.ylabel('Loss')
plt.title('Learning Rate Range Test')
plt.show()


In [None]:
from sklearn.model_selection import KFold
from torch.utils.data import SubsetRandomSampler, DataLoader
import numpy as np

# Random hyperparameter search: each of the k folds draws its own learning
# rate and batch size, trains a fresh UNet, and reports the held-out loss.
k_folds = 5
num_epochs = 10
search_seed = 42  # Seeds both the fold split and the hyperparameter draws

# `dataset` is the one built in the first cell; it is deliberately not
# reassigned here (the former `dataset = ...` placeholder replaced it with
# Ellipsis, which cannot be split or indexed).

# Reuse the `device` chosen by an earlier cell if there is one; otherwise pick
# one here so the cell also runs on a fresh kernel.
if "device" not in globals():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

rng = np.random.default_rng(search_seed)
kf = KFold(n_splits=k_folds, shuffle=True, random_state=search_seed)

results = []

# Split over plain indices so KFold only needs the dataset's length.
for fold, (train_ids, test_ids) in enumerate(kf.split(np.arange(len(dataset)))):
    # Randomly sample hyperparameters (learning rate is log-uniform in [1e-6, 1e-1])
    learning_rate = float(10 ** rng.uniform(-6, -1))
    # DataLoader only accepts a built-in int batch size; a NumPy integer is rejected.
    batch_size = int(rng.choice([16, 32, 64]))

    # Hand the samplers plain Python ints for indexing into the dataset.
    train_sampler = SubsetRandomSampler(train_ids.tolist())
    test_sampler = SubsetRandomSampler(test_ids.tolist())

    train_loader = DataLoader(dataset, batch_size=batch_size, sampler=train_sampler)
    test_loader = DataLoader(dataset, batch_size=batch_size, sampler=test_sampler)

    # The model must live on the same device as the batches moved there below.
    model = UNet().to(device)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    criterion = nn.BCEWithLogitsLoss()

    # Training loop for the fold
    for epoch in range(num_epochs):
        model.train()
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Evaluate on the held-out fold after every epoch
        model.eval()
        with torch.no_grad():
            total_loss = 0
            for images, labels in test_loader:
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                loss = criterion(outputs, labels)
                total_loss += loss.item()

        # Mean of the per-batch losses (each batch counts equally)
        avg_loss = total_loss / len(test_loader)
        print(f"Fold {fold}, Epoch {epoch}, Avg. Loss: {avg_loss}")

    # Only the final epoch's held-out loss is kept for the comparison.
    results.append({'fold': fold, 'lr': learning_rate, 'batch_size': batch_size, 'loss': avg_loss})

# Evaluate the results to find the best hyperparameters
# NOTE(review): every configuration is scored on a different fold, so the
# ranking mixes hyperparameter quality with fold difficulty.
best_run = sorted(results, key=lambda x: x['loss'])[0]
print("Best Hyperparameters:", best_run)