In [1]:
from data import dataset, PlantOrgansDataset
from preprocessing import preprocess_image_and_mask
import torchvision.transforms.v2 as T
import torch
import numpy as np
from alexnet import MyTransform, SlidingWindow, ExtractFeatures, get_extractor, get_feature
from train import device, pixel_validate, patch_loss, patch_validate, evaluate, fit

Resolving data files:   0%|          | 0/19 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/19 [00:00<?, ?it/s]

In [2]:
# Geometric preparation handed to PlantOrgansDataset together with the two
# type-specific pipelines below.
# NOTE(review): presumably applied to both the image and its mask — confirm in
# data.py, and check that label masks are not resized with an interpolating filter.
commonTransform = T.Compose([
    T.Resize(size=(2048, 2048)),
    T.ToImage(),

    # Augmentations, currently disabled:
    # T.RandomHorizontalFlip(p=0.5),
    # T.RandomVerticalFlip(p=0.5),
    # T.RandomRotation(degrees=45)
])

# Image-only pipeline: float conversion, patching, upscale to 224x224, ImageNet normalisation.
imagesTransform = T.Compose([
    # scale=True maps uint8 [0, 255] onto float [0, 1]. The ImageNet mean/std used by
    # Normalize below are defined on that range; with scale=False the raw 0-255 values
    # were normalised with statistics that do not match them.
    # NOTE(review): assumes MyTransform does not rescale values itself (it is also
    # applied to the integer label masks) — confirm in alexnet.py.
    T.ToDtype(torch.float32, scale=True),
    # T.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    MyTransform(64),
    T.Resize((224, 224)),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Mask-only pipeline: labels stay integer class ids (no rescaling), then the same patching step.
masksTransform = T.Compose([
    T.ToDtype(torch.int8, scale=False),
    # T.Normalize(mean=[0.0014], std=[0.0031]),
    MyTransform(64),
    # T.Resize((224, 224))
])

In [3]:
# Hold out 20% of the original 'train' split (fixed seed) to serve as validation data;
# the dataset's own 'validation' split is used as the test set.
train_validation_data = dataset['train'].train_test_split(seed=42, test_size=0.2)

# All three datasets share the same transform triple.
_shared_transforms = (commonTransform, imagesTransform, masksTransform)
train_dataset = PlantOrgansDataset(train_validation_data['train'], *_shared_transforms)
validation_dataset = PlantOrgansDataset(train_validation_data['test'], *_shared_transforms)
test_dataset = PlantOrgansDataset(dataset['validation'], *_shared_transforms)


In [4]:
# Convert the first training label to an int8 tensor without rescaling its values,
# then display the largest value it contains.
tr = T.Compose([T.ToImage(), T.ToDtype(dtype=torch.int8, scale=False)])
_first_label = train_validation_data['train'][0]['label']
tensor = tr(_first_label)
tensor.max()

tensor(4, dtype=torch.int8)

In [5]:
# Five weights stored on the training device — presumably one per class id 0-4
# (TODO confirm the class order).
# NOTE(review): no visible cell passes these to a loss function; verify they are used.
cross_entropy_weights = torch.tensor(
    [4.8033e-04, 6.4129e-03, 3.9272e-03, 9.7140e-01, 1.7778e-02],
    device=device,
)

In [6]:
# Report how many samples each split contains.
for _label, _split in (
    ("train_dataset: ", train_dataset),
    ("validation_dataset: ", validation_dataset),
    ("test_dataset: ", test_dataset),
):
    print(_label, len(_split))

train_dataset:  4596
validation_dataset:  1149
test_dataset:  1437


In [7]:
class WrappedDataLoader:
    """Wrap a loader so every ``(X, y)`` pair of each batch goes through ``func``.

    Each item produced by the wrapped loader must itself be an iterable of
    ``(X, y)`` pairs. For every such item the wrapper yields a list holding
    ``func(X, y)`` for each pair, in order.
    """

    def __init__(self, loader, func):
        self.func = func
        self.loader = loader

    def __len__(self):
        """Return the number of batches reported by the wrapped loader."""
        return len(self.loader)

    def __iter__(self):
        """Yield one list of transformed pairs per batch of the wrapped loader."""
        for pairs in self.loader:
            # The inner iterable is fully consumed before the list is handed out.
            yield [self.func(X, y) for X, y in pairs]

In [8]:
def to_device(X: torch.Tensor, y: torch.Tensor):
    """Move a sample pair to ``device``, casting ``X`` to float32 and ``y`` to int8."""
    X_on_device = X.to(device, dtype=torch.float32)
    y_on_device = y.to(device, dtype=torch.int8)
    return X_on_device, y_on_device

In [9]:
# Number of rows (entries along dim 0) per chunk emitted by custom_collate_fn.
batch_size = 1024

In [10]:
def custom_collate_fn(batch, chunk_size=None):
    """Re-chunk a batch of variable-length samples into fixed-size pieces.

    Every element of ``batch`` is an ``(images, masks)`` pair of tensors. The pairs
    are stitched together along dim 0, in order, and re-emitted as chunks of exactly
    ``chunk_size`` rows; whatever is left over at the end is emitted as one shorter
    chunk. Chunk boundaries are computed from ``len(images)``; the masks are assumed
    to have the same leading length.

    This is a generator function: used as a DataLoader ``collate_fn`` it hands the
    consumer a lazy iterator of ``(images, masks)`` chunks rather than a tensor.

    Args:
        batch: sequence of ``(images, masks)`` tensor pairs.
        chunk_size: rows per emitted chunk. When omitted, the notebook-level
            ``batch_size`` is looked up at call time (the previous, only behaviour).

    Yields:
        ``(images, masks)`` tuples with ``chunk_size`` rows each, except possibly
        the last one, which holds the remaining rows.

    Raises:
        ValueError: if the effective chunk size is not positive.
    """
    rows_per_chunk = batch_size if chunk_size is None else chunk_size
    if rows_per_chunk <= 0:
        raise ValueError(f"chunk size must be positive, got {rows_per_chunk}")

    pending_images = []
    pending_masks = []
    pending_rows = 0

    for images, masks in batch:
        pending_images.append(images)
        pending_masks.append(masks)
        pending_rows += len(images)
        if pending_rows < rows_per_chunk:
            # Not enough rows for a full chunk yet; keep accumulating.
            continue

        # Only concatenate when there is more than one piece, to avoid a needless copy.
        # (torch.cat is the same operation as torch.concatenate.)
        if len(pending_images) == 1:
            merged_images, merged_masks = pending_images[0], pending_masks[0]
        else:
            merged_images = torch.cat(pending_images)
            merged_masks = torch.cat(pending_masks)

        # Emit every complete chunk; slices are views, so no data is copied here.
        full_rows = pending_rows - pending_rows % rows_per_chunk
        for start in range(0, full_rows, rows_per_chunk):
            stop = start + rows_per_chunk
            yield merged_images[start:stop], merged_masks[start:stop]

        # Carry the incomplete tail (if any) over to the next sample.
        pending_rows -= full_rows
        if pending_rows > 0:
            pending_images = [merged_images[full_rows:]]
            pending_masks = [merged_masks[full_rows:]]
        else:
            pending_images = []
            pending_masks = []

    if pending_rows > 0:
        yield torch.cat(pending_images), torch.cat(pending_masks)



In [11]:
from torch.utils.data import DataLoader


def _make_loader(split):
    """Build the device-moving loader used for every split.

    The DataLoader draws one sample at a time, in order, and custom_collate_fn
    re-chunks it; WrappedDataLoader then sends each chunk through to_device.

    pin_memory is left at its default (False). pin_memory_device expects a string
    and only matters when pin_memory=True; the list that used to be passed here was
    never used and at best triggered a warning, so it is no longer passed.
    """
    return WrappedDataLoader(
        DataLoader(split, batch_size=1, shuffle=False, collate_fn=custom_collate_fn),
        to_device,
    )


train_loader = _make_loader(train_dataset)
valid_loader = _make_loader(validation_dataset)
test_loader = _make_loader(test_dataset)

In [12]:
# Converter from tensors to PIL images.
to_image = T.ToPILImage()

# Mask visualisation pipeline: tensor conversion, float16 cast,
# single-channel normalisation, then conversion to a PIL image.
_mask_steps = [
    T.ToTensor(),
    T.ToDtype(torch.float16),
    T.Normalize(mean=[0.0014], std=[0.0031]),
    T.ToPILImage(),
]
mask_to_image = T.Compose(_mask_steps)



In [13]:
import os
import torch.optim as optim
import time
from ray import tune
from ray.train import Checkpoint, get_checkpoint, report, RunConfig
from ray.tune.schedulers import ASHAScheduler

2024-11-16 14:26:29,175	INFO util.py:154 -- Outdated packages:
  ipywidgets==7.8.1 found, needs ipywidgets>=8
Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.
2024-11-16 14:26:29,692	INFO util.py:154 -- Outdated packages:
  ipywidgets==7.8.1 found, needs ipywidgets>=8
Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.


In [14]:
# Project root. Defaults to the original author's checkout; set the NN_LAB2_SRC
# environment variable to run the notebook from another location.
src_path = os.environ.get(
    "NN_LAB2_SRC", "C:\\Users\\pc\\Documents\\repos\\mp-2\\nn\\nn-lab2\\"
)

# Fixed training settings.
# NOTE(review): the criterion is unweighted; cross_entropy_weights defined earlier in
# this notebook is not passed to it — confirm that is intended.
constants = {
    "criterion": torch.nn.CrossEntropyLoss(),
    "lr": 0.0001,
    "n_epochs": 40,
    # os.path.join tolerates src_path with or without a trailing separator.
    "saving_model_path": os.path.join(src_path, "models", "raytune"),
}

# Ray Tune search space (a single grid point per hyper-parameter at the moment).
config = {
    "batch_size": tune.grid_search([64 * 64]),
    "patch_size": tune.grid_search([32]),
}

In [15]:
# PIL image -> normalised float tensor at 2048x2048.
# ToImage + ToDtype(scale=True) is the documented replacement for the deprecated
# ToTensor and produces the same float32 values in [0, 1]. Previously ToDtype sat in
# front of ToTensor, where it received the PIL image and left it untouched.
image_to_tensor = T.Compose([
    T.ToImage(),
    T.ToDtype(dtype=torch.float32, scale=True),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    T.Resize(size=(2048, 2048)),
])




In [16]:
# Take the first training image and turn it into a single-item batch on `device`.
image = train_validation_data['train'][0]['image']
X = torch.unsqueeze(image_to_tensor(image), 0).to(device)

In [17]:
# Pretrained AlexNet from the torchvision v0.10.0 hub entry, moved to `device`,
# switched to eval mode and frozen (no parameter requires gradients).
model = torch.hub.load('pytorch/vision:v0.10.0', 'alexnet', pretrained=True)
model = model.to(device).eval()
# requires_grad_ returns the module itself, so the cell still displays its repr.
model.requires_grad_(False)

Using cache found in C:\Users\pc/.cache\torch\hub\pytorch_vision_v0.10.0


AlexNet(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (4): ReLU(inplace=True)
    (5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU(inplace=True)
    (8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU(inplace=True)
    (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(6, 6))
  (classifier): Sequential(
    (0): Dropout(p=0.5, inplace=False)
    (1): Linear(in_features=9216, out_features=4096, bias=True)
 

In [None]:
# Patch extractor. In the recorded run, a [1, 3, 2048, 2048] input produced a
# [512, 512, 3, 16, 16] tensor of patches.
image_to_sliding_patches = T.Compose([SlidingWindow(16, 6)])

In [19]:
# Sanity check: shape of the single-image input batch.
X.shape

torch.Size([1, 3, 2048, 2048])

In [20]:
# Cut the input batch into sliding-window patches.
patches = image_to_sliding_patches(X)

In [21]:
# Sanity check: shape of the patch tensor.
patches.shape

torch.Size([512, 512, 3, 16, 16])

In [22]:
# One 9216-long feature vector per patch position (first two dims of `patches`).
# 9216 = 256 * 6 * 6, the flattened size of AlexNet's avgpool output.
# For a [512, 512, ...] patch grid this float32 buffer takes about 9 GiB of host memory.
features = torch.zeros(*patches.shape[:2], 9216)

In [23]:
# Resize patches to 224x224 before they are fed to the network.
upscale = T.Compose([T.Resize((224, 224))])

In [24]:
import torch.utils.data as data_utils

# Iterate over the first dimension of `patches`, one row of patches per step.
patches_dataset = data_utils.TensorDataset(patches)
patches_loader = DataLoader(patches_dataset, batch_size=1, shuffle=False)

# Loop-invariant: build the avgpool extractor once instead of once per row.
# NOTE(review): assumes get_extractor is a side-effect-free factory — confirm in alexnet.py.
avgpool_extractor = get_extractor(device, model, "avgpool")

for i_h, batch in enumerate(patches_loader):
    # Report progress occasionally instead of printing every row index.
    if i_h % 64 == 0:
        print(f"row {i_h}/{len(patches_loader)}")

    # batch is a 1-tuple (TensorDataset) with a leading batch dim of 1 (batch_size=1).
    row_patches = batch[0][0]
    upscaled = upscale(row_patches)

    # The former extra `model(upscaled)` call was dropped: its result was discarded,
    # so it only cost a second full forward pass per row.
    row_features = get_feature(upscaled, avgpool_extractor, "avgpool")
    features[i_h] = row_features.view(row_patches.size(0), -1).to("cpu")

0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
27

In [25]:
# Ask PyTorch's caching allocator to release unused cached GPU memory.
torch.cuda.empty_cache()

In [30]:
# Move the complete feature tensor to `device` as float16. For the recorded
# 262144 x 9216 features this alone is 4.5 GiB of device memory.
# NOTE(review): this cell carries execution count 30, i.e. it was run out of order;
# re-check with Restart & Run All.
features = features.to(device, dtype=torch.float16)

In [27]:
from kmeans import KMeans 

In [28]:
# Shape of the features once the two patch-grid dims are flattened into one row axis.
features.view(-1, features.size(2)).shape

torch.Size([262144, 9216])

In [32]:
# Cluster the flattened patch features with 5 as the second argument (presumably
# the number of clusters — confirm in kmeans.py).
# NOTE(review): the recorded run failed here with a CUDA out-of-memory error: a single
# 22.50 GiB allocation on an 11 GiB GPU. 22.5 GiB = 5 x (262144 x 9216 x 2 bytes), i.e.
# five copies of the float16 feature matrix, which is consistent with KMeans
# materialising one copy of the data per cluster — confirm in kmeans.py. Likely
# remedies: compute distances in chunks, subsample rows, or reduce the feature dimension.
image_means = KMeans(features.view(-1, features.size(2)), 5)

OutOfMemoryError: CUDA out of memory. Tried to allocate 22.50 GiB. GPU 0 has a total capacity of 11.00 GiB of which 0 bytes is free. Of the allocated memory 19.32 GiB is allocated by PyTorch, and 138.70 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)