In [1]:
from data import dataset, PlantOrgansDataset
from preprocessing import preprocess_image_and_mask
import torchvision.transforms.v2 as T
import torch
import numpy as np
from alexnet import MyTransform, SlidingWindow, ExtractFeatures, get_extractor, get_feature, get_model
from train import device, pixel_validate, patch_loss, patch_validate, evaluate, fit
import torch.utils.data as data_utils
from kmeans import KMeans, KNN
from evaluate import calculate_metrics
import torchmetrics
from torch.utils.data import DataLoader
from evaluate import perform_segmentation, segmentation_image, retrieve_features

Resolving data files:   0%|          | 0/19 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/19 [00:00<?, ?it/s]

Using cache found in C:\Users\pc/.cache\torch\hub\pytorch_vision_v0.10.0
Using cache found in C:\Users\pc/.cache\torch\hub\pytorch_vision_v0.10.0


In [2]:
# Transform shared by images and masks: resize to a fixed 2048x2048 canvas and
# convert to a torchvision image tensor. The augmentations below are disabled.
commonTransform = T.Compose([
        T.Resize(size=(2048, 2048)),
        T.ToImage()
        
        # T.RandomHorizontalFlip(p=0.5),
        # T.RandomVerticalFlip(p=0.5),
        # T.RandomRotation(degrees=45)
    ])
# Image-only pipeline: cast to float32, apply the project transform, resize the
# result to 224x224 and normalise with the ImageNet channel statistics.
# NOTE(review): `scale=False` keeps the original value range, while the
# Normalize mean/std below are expressed for inputs in [0, 1] — confirm whether
# MyTransform rescales, otherwise the normalisation is off by a factor of 255.
imagesTransform = T.Compose([
    T.ToDtype(torch.float32, scale=False),
    # T.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    # Presumably cuts the image into 64-pixel patches — TODO confirm in alexnet.py.
    MyTransform(64),
    T.Resize((224, 224)),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
# Mask-only pipeline: cast the label values to int8 (no rescaling) and apply the
# same project transform with the same argument as for images; masks are not
# resized to 224x224.
masksTransform = T.Compose([
    T.ToDtype(torch.int8, scale=False),
    # T.Normalize(mean=[0.0014], std=[0.0031]),
    MyTransform(64),
    # T.Resize((224, 224))
])

In [3]:
# Hold out 20% of the original 'train' split for validation; the fixed seed
# keeps the partition reproducible across runs.
train_validation_data = dataset['train'].train_test_split(test_size=0.2, seed=42)
train_dataset = PlantOrgansDataset(train_validation_data['train'], commonTransform, imagesTransform, masksTransform)
validation_dataset = PlantOrgansDataset(train_validation_data['test'], commonTransform, imagesTransform, masksTransform)
# The dataset's own 'validation' split is used here as the final test set.
test_dataset = PlantOrgansDataset(dataset['validation'], commonTransform, imagesTransform, masksTransform)


In [4]:
# Five per-class weights, created directly on the training device.
# NOTE(review): presumably meant for a weighted cross-entropy loss, but the
# `CrossEntropyLoss()` stored in `constants` further down is built without
# them — confirm whether the weights are consumed elsewhere.
cross_entropy_weights = torch.tensor([
        4.8033e-04,
        6.4129e-03,
        3.9272e-03,
        9.7140e-01,
        1.7778e-02], device=device)

In [5]:
# Sanity check: report how many samples ended up in each split.
print("train_dataset: ", len(train_dataset))
print("validation_dataset: ", len(validation_dataset))
print("test_dataset: ", len(test_dataset))

train_dataset:  4596
validation_dataset:  1149
test_dataset:  1437


In [6]:
class WrappedDataLoader:
    """Apply ``func`` to every ``(X, y)`` pair produced by a wrapped loader.

    Each item yielded by ``loader`` must itself be an iterable of two-element
    pairs. Iterating the wrapper yields, for every such item, a list holding
    ``func(X, y)`` for each pair, in the original order.
    """

    def __init__(self, loader, func):
        self.loader, self.func = loader, func

    def __len__(self):
        # The wrapper has exactly as many items as the loader it wraps.
        return len(self.loader)

    def __iter__(self):
        for pairs in self.loader:
            # The whole item is consumed eagerly before it is handed out.
            yield [self.func(inputs, targets) for inputs, targets in pairs]

In [7]:
def to_device(X: torch.Tensor, y: torch.Tensor):
    """Move an input/target pair to the training device.

    Inputs are cast to float32 and targets to int8 as part of the transfer.
    Returns the pair ``(inputs, targets)``.
    """
    inputs_on_device = X.to(device, dtype=torch.float32)
    targets_on_device = y.to(device, dtype=torch.int8)
    return inputs_on_device, targets_on_device

In [8]:
# Number of rows per batch emitted by custom_collate_fn, which reads this
# module-level value when it re-chunks the samples it receives.
batch_size = 1024

In [9]:
def custom_collate_fn(batch, chunk_size=None):
    """Re-chunk a list of ``(images, masks)`` samples into fixed-size batches.

    The tensors of consecutive samples are concatenated along dimension 0 and
    handed out in chunks of exactly ``chunk_size`` rows, in the original order.
    Whatever is left after the last sample (fewer than ``chunk_size`` rows) is
    yielded as one final, smaller chunk.

    Args:
        batch: sequence of ``(images, masks)`` tensor pairs whose first
            dimensions have the same length within a pair.
        chunk_size: number of rows per yielded chunk. Defaults to the
            module-level ``batch_size`` so the function can still be passed
            directly as a DataLoader ``collate_fn``.

    Yields:
        ``(images, masks)`` tuples of tensors.

    Raises:
        ValueError: if the chunk size is not positive. Because this is a
            generator, the error surfaces when iteration starts.
    """
    if chunk_size is None:
        chunk_size = batch_size
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")

    pending_images = []
    pending_masks = []
    pending_length = 0
    for images, masks in batch:
        pending_images.append(images)
        pending_masks.append(masks)
        pending_length += len(images)
        if pending_length < chunk_size:
            # Not enough rows for a full chunk yet; keep accumulating.
            continue

        # A single pending tensor needs no concatenation (and no copy).
        if len(pending_images) == 1:
            merged_images, merged_masks = pending_images[0], pending_masks[0]
        else:
            merged_images = torch.cat(pending_images)
            merged_masks = torch.cat(pending_masks)

        image_chunks = list(torch.split(merged_images, chunk_size, dim=0))
        mask_chunks = list(torch.split(merged_masks, chunk_size, dim=0))
        if len(image_chunks[-1]) < chunk_size:
            # Carry the short tail over so it can be topped up by later samples.
            pending_images = [image_chunks.pop()]
            pending_masks = [mask_chunks.pop()]
            pending_length = len(pending_images[0])
        else:
            pending_images, pending_masks, pending_length = [], [], 0
        yield from zip(image_chunks, mask_chunks)

    if pending_length > 0:
        yield torch.cat(pending_images), torch.cat(pending_masks)



In [10]:

def _make_loader(dataset, pin_memory):
    """Build the re-chunking loader for one dataset split.

    The DataLoader fetches one sample at a time, in order, and relies on
    ``custom_collate_fn`` to cut it into fixed-size batches; the wrapper then
    moves every batch to the training device via ``to_device``.

    ``pin_memory_device`` is documented as a string, so the device is passed
    as ``str(device)`` rather than wrapped in a list, and it is only supplied
    when pinning is actually requested (PyTorch warns when a pin device is set
    while ``pin_memory`` is False).
    """
    loader_kwargs = {
        "batch_size": 1,
        "shuffle": False,
        "collate_fn": custom_collate_fn,
        "pin_memory": pin_memory,
    }
    if pin_memory:
        loader_kwargs["pin_memory_device"] = str(device)
    return WrappedDataLoader(DataLoader(dataset, **loader_kwargs), to_device)


# NOTE(review): the training loader is not shuffled, matching the original
# behaviour — confirm this is intended for training.
train_loader = _make_loader(train_dataset, pin_memory=True)
valid_loader = _make_loader(validation_dataset, pin_memory=False)
test_loader = _make_loader(test_dataset, pin_memory=False)

In [11]:
import os
import torch.optim as optim
import time
from ray import tune
from ray.train import Checkpoint, get_checkpoint, report, RunConfig
from ray.tune.schedulers import ASHAScheduler

2024-11-17 01:06:48,040	INFO util.py:154 -- Outdated packages:
  ipywidgets==7.8.1 found, needs ipywidgets>=8
Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.
2024-11-17 01:06:48,347	INFO util.py:154 -- Outdated packages:
  ipywidgets==7.8.1 found, needs ipywidgets>=8
Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.


In [12]:
# Root directory used to locate saved models and cached feature files.
# NOTE(review): absolute, machine-specific Windows path — the notebook will not
# run elsewhere without editing it.
src_path = "C:\\Users\\pc\\Documents\\repos\\mp-2\\nn\\nn-lab2\\"

# Fixed training settings.
# NOTE(review): the criterion is an unweighted CrossEntropyLoss; the
# `cross_entropy_weights` tensor defined earlier is not passed to it.
constants = {
    "criterion": torch.nn.CrossEntropyLoss(),
    "lr": 0.0001,
    "n_epochs": 40,
    "saving_model_path": src_path + "models\\raytune"
}
# Ray Tune search space; each grid currently contains a single value.
# NOTE(review): "batch_size" here (64*64) differs from the module-level
# `batch_size` read by custom_collate_fn — confirm which one is in effect.
config = {
    "batch_size": tune.grid_search([64*64]),
    "patch_size": tune.grid_search([32])
}

In [13]:
# Image pipeline: convert to an image tensor, scale to float32 in [0, 1],
# normalise with the ImageNet channel statistics, then resize to 2048x2048.
# The conversion must come first: v2 `ToDtype` leaves a PIL image untouched, so
# the previous order (ToDtype before the deprecated `ToTensor`) made it a
# no-op. `ToImage` + `ToDtype(scale=True)` is the documented v2 replacement
# for `ToTensor`.
image_to_tensor = T.Compose([
    T.ToImage(),
    T.ToDtype(dtype=torch.float32, scale=True),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    T.Resize(size=(2048, 2048)),
])
# Mask pipeline: convert to an image tensor and cast to float32 without
# rescaling, so the stored label values are kept as they are.
mask_to_tensor = T.Compose([
    T.ToImage(),
    T.ToDtype(dtype=torch.float32, scale=False)
])
# Same as mask_to_tensor, but the mask is first resized to 2048x2048 with
# nearest-neighbour interpolation so that no new label values are invented.
mask_of_uniform_size = T.Compose([
    T.Resize((2048, 2048), interpolation=T.InterpolationMode.NEAREST_EXACT),
    mask_to_tensor,
])



In [14]:
# Prepare one training sample (index 1) for an ad-hoc segmentation check:
# the image gets a leading batch dimension, the mask is resized to 2048x2048,
# and both are moved to the training device.
image = train_validation_data['train'][1]['image']
X = image_to_tensor(image).unsqueeze(0).to(device)
y = mask_of_uniform_size(train_validation_data['train'][1]['label']).to(device)

In [15]:
# T.Resize((224, 224))(train_validation_data['train'][1]['image'])
# T.Resize((224, 224))(mask_to_image(train_validation_data['train'][1]['label']))

from torchvision.models.feature_extraction import get_graph_node_names
# get_graph_node_names(model)
# Release cached, currently unused GPU memory held by PyTorch's allocator.
torch.cuda.empty_cache()


# import torch
# import gc
# objects = []
# for obj in gc.get_objects():
#     try:
#         if (torch.is_tensor(obj) or (hasattr(obj, 'data') and torch.is_tensor(obj.data))) and obj.is_cuda:
#             objects.append((type(obj), obj.size(), obj.numel()))
#     except:
#         pass

# sorted_by_size = sorted(objects, key=lambda tup: tup[2], reverse=True)
# sorted_by_size

In [16]:
# predicted_classes = perform_segmentation(X, y.unsqueeze(0))
# segmentation_image(predicted_classes)

In [17]:
# Drop the sample tensors created above and release the GPU memory they used.
# Not idempotent: re-running this cell without re-creating X and y raises
# NameError.
del X, y
torch.cuda.empty_cache()

In [18]:
# segmentation_image(mask_to_tensor(train_validation_data['train'][1]['label'])[0])

In [19]:
# image = train_validation_data['train'][1]['image']
# X = image_to_tensor(image).unsqueeze(0).to(device)

In [20]:
# import tqdm
# for i in tqdm.tqdm(range(100)):
#     image = train_validation_data['train'][i]['image']
#     features = retrieve_features(image_to_tensor(image).unsqueeze(0).to(device))
#     save_features(i, features)

In [21]:
import pickle
import os
import lz4

def read_features(index, src_path, layer_name="classifier.0", sliding_window_size=32, 
                         sliding_window_step=11):
    """Load the cached feature object of one training image from disk.

    The file is looked up at
    ``<src_path>/features/train_<layer_name>_<window_size>_<window_step>_<index>.lz4``
    and holds two consecutive pickles: a shape header followed by the
    features themselves.

    Args:
        index: index of the image the features were saved under.
        src_path: directory that contains the ``features`` folder.
        layer_name: layer name embedded in the file name.
        sliding_window_size: window size embedded in the file name.
        sliding_window_step: window step embedded in the file name.

    Returns:
        The second pickled object in the file (the features).

    Raises:
        FileNotFoundError: if no cache file exists for these parameters.
    """
    # A bare `import lz4` does not load the `lz4.frame` submodule; import it
    # explicitly so this function does not depend on another module having
    # done so already.
    import lz4.frame

    file_name = os.path.join(
        src_path,
        "features",
        f"train_{layer_name}_{sliding_window_size}_{sliding_window_step}_{index}.lz4",
    )
    # SECURITY: pickle.load executes arbitrary code embedded in the file; only
    # read cache files produced by this project.
    with lz4.frame.open(file_name, mode="rb") as f:
        # The shape header is not needed, but it must be consumed to advance
        # the stream to the feature payload.
        pickle.load(f)
        return pickle.load(f)

In [22]:
# knn = KNN(features.view(features.size(0) * features.size(1), -1), ground_truth_patches, k=5)

In [23]:
# Load the cached features of training images 0-4 on the CPU and collapse the
# two leading dimensions of each, giving one row per position.
knn_features = [
    image_features.view(image_features.size(0) * image_features.size(1), -1)
    for image_features in (read_features(index, src_path).to("cpu") for index in range(5))
]
torch.cuda.empty_cache()
# Stack the rows of all images and move the result to the device in one go.
merged_knn_features = torch.cat(knn_features, dim=0).to(device)

  return torch.load(io.BytesIO(b))


In [24]:
# Release cached, currently unused GPU memory held by PyTorch's allocator.
torch.cuda.empty_cache()

In [25]:
def get_ground_truth_for_indices(indices, window_size=32, window_step=11):
    """Build one label per sliding-window patch for the given training samples.

    For every index the sample's mask is resized with ``mask_of_uniform_size``,
    cut into patches with ``SlidingWindow(window_size, window_step)``, and each
    patch is reduced to its most frequent value (``torch.mode``). The labels of
    all samples are concatenated in the order of ``indices``.

    Args:
        indices: iterable of indices into ``train_validation_data['train']``.
        window_size: first argument passed to ``SlidingWindow``; the default
            matches the default used by ``read_features``.
        window_step: second argument passed to ``SlidingWindow``; the default
            matches the default used by ``read_features``.

    Returns:
        A 1-D int8 tensor on ``device``; empty when ``indices`` is empty.
    """
    # Built once; presumably a stateless transform — TODO confirm in alexnet.py.
    sliding_window = SlidingWindow(window_size, window_step)
    ground_truths = []
    for i in indices:
        ground_truth = mask_of_uniform_size(train_validation_data['train'][i]['label']).to("cpu")
        mask_patches = sliding_window(ground_truth.unsqueeze(0))
        # Flatten each patch so the mode can be taken over all of its values.
        flat_patches = mask_patches.view(mask_patches.size(0), mask_patches.size(1), -1)
        patch_labels = torch.mode(flat_patches, dim=2).values
        ground_truths.append(patch_labels.to(dtype=torch.int8).view(-1))
    if not ground_truths:
        # torch.cat rejects an empty list; return an empty label vector instead.
        return torch.empty(0, dtype=torch.int8, device=device)
    return torch.cat(ground_truths, dim=0).to(device)

In [26]:
# Labels must line up one-to-one with the rows of merged_knn_features, so build
# them for exactly the images whose cached features were loaded into
# knn_features (feature file i belongs to training sample i). The previous
# hard-coded range(10) produced labels for twice as many images as features.
ground_truths = get_ground_truth_for_indices(range(len(knn_features)))

In [27]:
# Display the merged per-patch label vector as a quick sanity check.
ground_truths

tensor([0, 0, 0,  ..., 0, 0, 0], device='cuda:0', dtype=torch.int8)

In [28]:
# Build the project's KNN classifier from the merged patch features and their
# labels. The third argument is presumably k, the number of neighbours —
# TODO confirm in kmeans.py.
knn = KNN(merged_knn_features, ground_truths, 5)

In [29]:
# Features of the image to segment, kept on the CPU. read_features loads
# "train_..." files, so index 11 is a training image; it lies outside the
# indices loaded into knn_features.
test_features = read_features(11, src_path).to("cpu")
torch.cuda.empty_cache()

In [41]:
# Inspect the feature layout; the prediction loop below depends on it.
test_features.shape

torch.Size([187, 187, 9216])

In [None]:
# Predict one class per patch, a few grid rows at a time to bound GPU memory.
# Sizes come from the data instead of being hard-coded (187, 187, 9216).
n_rows, n_cols, n_features = test_features.shape
rows_per_chunk = 24
predicted_mask = torch.zeros((n_rows, n_cols))
for i in range(0, n_rows, rows_per_chunk):
    # The slice must span the full stride; the previous `i + 23` bound with a
    # step of 24 left every 24th row unpredicted.
    i_upper_limit = min(i + rows_per_chunk, n_rows)
    chunk = test_features[i : i_upper_limit].view(-1, n_features)
    # NOTE(review): the int64 cast is kept from the original but truncates
    # floating-point features — confirm KNN.predict really needs integers.
    predicted_classes = knn.predict(chunk.to(device, dtype=torch.int64))
    torch.cuda.empty_cache()
    # Assumes predict returns one label per query row — TODO confirm; the
    # labels are folded back into (rows, n_cols) to match the mask slice.
    predicted_mask[i : i_upper_limit] = predicted_classes.reshape(-1, n_cols).to("cpu")
    print(i)
        
    

KeyboardInterrupt: 

In [44]:
# Release cached, currently unused GPU memory held by PyTorch's allocator.
torch.cuda.empty_cache()