In [1]:
import torch
import torch.utils.data as torch_split
import torch.nn as nn
from torch.nn import MSELoss
import numpy as np
import dataset
import torch.optim as optim
import torch.nn as nn
from torch.utils.data import DataLoader
import torch.optim.lr_scheduler
from monai.losses import DiceLoss
from monai.losses import FocalLoss
from monai.networks.nets import UNet
import sys
sys.path.insert(1, 'H:/Projects/Kaggle/CZII-CryoET-Object-Identification/preprocessing')
sys.path.insert(1, 'H:/Projects/Kaggle/CZII-CryoET-Object-Identification/postprocessing')
import visual

import metrics

In [2]:
path = "H:/Projects/Kaggle/CZII-CryoET-Object-Identification/datasets/3D/dim104-heat-map-700"
data = dataset.UNetDataset(path=path)

tv_split = 0.8
trn = int(len(data) * tv_split)
val = len(data) - trn

# train_dataset, val_dataset = torch_split.random_split(data, [trn, val])

train_dataset = dataset.UNetDataset(path=path, train=True)
val_dataset = dataset.UNetDataset(path=path, val=True)

labels = [
"background",
"apo-ferritin(E)",
"beta-amylase(NS)",
"beta-galactosidase(H)",
"ribosome(E)",
"thyroglobulin(H)",
"virus-like-particle(E)"
]

In [3]:
import optuna
import torch
from torch.utils.data import DataLoader
from torch.optim import Adam
from monai.losses import DiceLoss
from monai.networks.nets import UNet
from optuna.pruners import MedianPruner

# Define the objective function for Optuna
def objective(trial):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    # ------------------------------ #
    #        HYPERPARAMETERS         #
    # ------------------------------ #
    lr = trial.suggest_float("lr", 1e-5, 1e-2, log=True)
    decay = trial.suggest_float('decay', 0.3, 1.0)
    # dropout = trial.suggest_float("dropout", 0.25, 0.5)
    dropout = 0.3
    # regularization_strength = trial.suggest_float("regularization_strength", 1e-4, 1e-2, log=True)
    # alpha = trial.suggest_float("alpha", 0.25, 1.0)
    # theta = trial.suggest_float("theta", 0.1, 0.9)
    # theta = 0.6
    # gamma = trial.suggest_float("gamma", 2.0, 5.0)
    
    # Model initialization
    model = UNet(
        spatial_dims=3,
        in_channels=1,
        out_channels=7,
        channels=(64, 128, 256, 512),
        strides=(2, 2, 2),
        num_res_units=2,
        dropout=dropout,
    ).to(device)
    
    num_epochs = 30
    batch_size = 16

    # ------------------------------ #
    #        TRAINING METHODS        #
    # ------------------------------ #
    weights = torch.tensor([0.0434743, 1.16546, 1.1661, 1.16513, 1.14281, 1.15554, 1.16149]).to(device)  # Example weights for classes

    # dice_loss = DiceLoss(to_onehot_y=False, softmax=True, weight=weights).to(device)
    # focal_loss = FocalLoss(to_onehot_y=False, use_softmax=True, weight=weights, gamma=gamma ).to(device)
    mse_loss = nn.MSELoss(reduction='none')
    
    optimizer = Adam(model.parameters(), lr=lr)

    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=4, gamma=decay)
    
    train_loader = DataLoader(train_dataset, batch_size=batch_size, collate_fn=dataset.collate_fn, shuffle=True, num_workers=4)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, collate_fn=dataset.collate_fn, shuffle=False, num_workers=4)

    def add_regularization_loss(model, regularization_type, regularization_strength):
        reg_loss = 0
        for param in model.parameters():
            reg_loss += torch.sum(param ** 2)
        return regularization_strength * reg_loss

    # ------------------------------ #
    #             TRAIN              #
    # ------------------------------ #
    for epoch in range(num_epochs):
        model.train()
        for batch in train_loader:
            input, target = batch['src'].to(device), batch['tgt'].to(device)
            optimizer.zero_grad()
            output = model(input)
            # loss = (theta) * dice_loss(outputs, targets) + (1 - theta) * focal_loss(outputs, targets)
            indv_loss = mse_loss(output, target)
            weighted_loss = (indv_loss * weights.view(1, -1, 1, 1, 1)).mean()
            loss = weighted_loss
            
            # reg_loss = add_regularization_loss(model, "L2", regularization_strength)

            loss.backward()
            optimizer.step()

        scheduler.step()
            
        # ---------- #
        # VALIDATION #
        # ---------- #
        model.eval()
        val_loss = 0
        with torch.no_grad():
            for batch in val_loader:
                input, target = batch['src'].float().to(device), batch['tgt'].long().to(device)
                output = model(input)
                # loss = (theta) * dice_loss(outputs, targets) + (1 - theta) * focal_loss(outputs, targets)
                loss = mse_loss(output, target)
                weighted_loss = (loss * weights.view(1, -1, 1, 1, 1)).mean()
                val_loss += weighted_loss.item()
                
        # print("batch done")
        val_loss /= len(val_loader)
        
        trial.report(val_loss, epoch)
        print(f"Epoch {epoch} loss: {val_loss}")
        
        if trial.should_prune():
            
            raise optuna.exceptions.TrialPruned()

    return val_loss

n_epochs = 25

study = optuna.create_study(direction="minimize", pruner=MedianPruner(n_startup_trials=4, n_warmup_steps=25))
study.optimize(objective, n_trials=15)

print("Best hyperparameters:", study.best_params)

[I 2024-12-30 08:17:37,592] A new study created in memory with name: no-name-b97bc9c1-8be6-4593-ad7a-82cef1a83e25


Epoch 0 loss: 0.19253355099095237
Epoch 1 loss: 0.11746144874228372
Epoch 2 loss: 0.0855516286359893
Epoch 3 loss: 0.06825231264034907
Epoch 4 loss: 0.058345901055468455
Epoch 5 loss: 0.05078086960646841
Epoch 6 loss: 0.04474914198120435
Epoch 7 loss: 0.039943107300334506
Epoch 8 loss: 0.03659689881735378
Epoch 9 loss: 0.03378640198045307
Epoch 10 loss: 0.0314084357685513
Epoch 11 loss: 0.0293273131052653
Epoch 12 loss: 0.027723345905542374
Epoch 13 loss: 0.026302766882710986
Epoch 14 loss: 0.02498713756601016
Epoch 15 loss: 0.023873560337556735
Epoch 16 loss: 0.022953530980481043
Epoch 17 loss: 0.0220837342656321
Epoch 18 loss: 0.021223079413175583
Epoch 19 loss: 0.020515827875998285
Epoch 20 loss: 0.019931494982706174
Epoch 21 loss: 0.019347550968329113
Epoch 22 loss: 0.018852649049626455
Epoch 23 loss: 0.01835367062853442
Epoch 24 loss: 0.01790864910516474
Epoch 25 loss: 0.017505693559845287
Epoch 26 loss: 0.017096587353282504
Epoch 27 loss: 0.01678870576951239
Epoch 28 loss: 0.0164

[I 2024-12-30 08:52:50,073] Trial 0 finished with value: 0.016113667231467035 and parameters: {'lr': 1.985179111620316e-05, 'decay': 0.8311906302523631}. Best is trial 0 with value: 0.016113667231467035.


Epoch 29 loss: 0.016113667231467035
Epoch 0 loss: 0.12776863078276315
Epoch 1 loss: 0.0702606663107872
Epoch 2 loss: 0.0479134540590975
Epoch 3 loss: 0.03514339733454916
Epoch 4 loss: 0.02781997538275189
Epoch 5 loss: 0.022498129556576412
Epoch 6 loss: 0.018812078775631055
Epoch 7 loss: 0.016030598017904494
Epoch 8 loss: 0.014188543893396854
Epoch 9 loss: 0.012756804728673564
Epoch 10 loss: 0.011563996163507303
Epoch 11 loss: 0.010616312631302409
Epoch 12 loss: 0.009907609472672144
Epoch 13 loss: 0.009328823847075304
Epoch 14 loss: 0.00882801247967614
Epoch 15 loss: 0.008394705545571115
Epoch 16 loss: 0.007999508020778498
Epoch 17 loss: 0.007585719796932406
Epoch 18 loss: 0.007507151717113124
Epoch 19 loss: 0.0072546713054180145
Epoch 20 loss: 0.0070174927823245525
Epoch 21 loss: 0.006842651663141118
Epoch 22 loss: 0.00668547638795442
Epoch 23 loss: 0.006527445072101222
Epoch 24 loss: 0.006465236863328351
Epoch 25 loss: 0.00630714138969779
Epoch 26 loss: 0.006137951225456264
Epoch 27 l

[I 2024-12-30 09:33:57,601] Trial 1 finished with value: 0.005857134858767192 and parameters: {'lr': 5.570803796931761e-05, 'decay': 0.9054139490638864}. Best is trial 1 with value: 0.005857134858767192.


Epoch 29 loss: 0.005857134858767192
Epoch 0 loss: 0.14496521486176384
Epoch 1 loss: 0.08318402121464412
Epoch 2 loss: 0.06092141610052851
Epoch 3 loss: 0.048392657190561295
Epoch 4 loss: 0.04047441358367602
Epoch 5 loss: 0.03461386470331086
Epoch 6 loss: 0.030011072133978207
Epoch 7 loss: 0.026576022514038615
Epoch 8 loss: 0.023620856718884573
Epoch 9 loss: 0.0211964568330182
Epoch 10 loss: 0.01915929363005691
Epoch 11 loss: 0.017478869400090642
Epoch 12 loss: 0.015995852028330166
Epoch 13 loss: 0.014783072595794996
Epoch 14 loss: 0.013623074317971865
Epoch 15 loss: 0.012677222386830382
Epoch 16 loss: 0.011965354904532433
Epoch 17 loss: 0.011186282771329084
Epoch 18 loss: 0.010672268250750171
Epoch 19 loss: 0.010105719996823205
Epoch 20 loss: 0.009658277655641237
Epoch 21 loss: 0.009226037085884146
Epoch 22 loss: 0.00888796140336328
Epoch 23 loss: 0.008582152012321684
Epoch 24 loss: 0.00829888621552123
Epoch 25 loss: 0.007983543082243867
Epoch 26 loss: 0.007808267656299803
Epoch 27 los

[I 2024-12-30 10:15:08,482] Trial 2 finished with value: 0.007185239810496569 and parameters: {'lr': 3.338928000526938e-05, 'decay': 0.9552096327693436}. Best is trial 1 with value: 0.005857134858767192.


Epoch 29 loss: 0.007185239810496569
Epoch 0 loss: 0.13957437376181284
Epoch 1 loss: 0.07687035120195812
Epoch 2 loss: 0.05412744316789839
Epoch 3 loss: 0.04019901860091421
Epoch 4 loss: 0.033688304324944816
Epoch 5 loss: 0.028545637097623613
Epoch 6 loss: 0.02456811649931802
Epoch 7 loss: 0.021426697158151202
Epoch 8 loss: 0.019622900419765048
Epoch 9 loss: 0.018051374703645706
Epoch 10 loss: 0.016801605621973675
Epoch 11 loss: 0.015602102296219932
Epoch 12 loss: 0.014801653826402294
Epoch 13 loss: 0.014174140782819854
Epoch 14 loss: 0.013539777758220831
Epoch 15 loss: 0.013031380044089423
Epoch 16 loss: 0.012704828650587134
Epoch 17 loss: 0.012332595367398527
Epoch 18 loss: 0.011999261048105028
Epoch 19 loss: 0.011701410843266381
Epoch 20 loss: 0.011495177116658952
Epoch 21 loss: 0.011320716494487392
Epoch 22 loss: 0.011098002394040426
Epoch 23 loss: 0.010888134957187705
Epoch 24 loss: 0.010785074801080756
Epoch 25 loss: 0.010640380800598197
Epoch 26 loss: 0.010523190100987753
Epoch 2

[I 2024-12-30 10:56:18,641] Trial 3 finished with value: 0.010241792744232548 and parameters: {'lr': 5.327108874232668e-05, 'decay': 0.6958264819928195}. Best is trial 1 with value: 0.005857134858767192.


Epoch 29 loss: 0.010241792744232548
Epoch 0 loss: 0.002617563664292296
Epoch 1 loss: 0.0015917308887259827
Epoch 2 loss: 0.001401357042292754
Epoch 3 loss: 0.0013980808564358288
Epoch 4 loss: 0.001329328761332565
Epoch 5 loss: 0.001348303449857566
Epoch 6 loss: 0.0012468152757113178
Epoch 7 loss: 0.0012058062970431314
Epoch 8 loss: 0.0011929027839667266
Epoch 9 loss: 0.0012398934147010248
Epoch 10 loss: 0.0011928370739850733
Epoch 11 loss: 0.0011820996878668666
Epoch 12 loss: 0.001135827378473348
Epoch 13 loss: 0.0012087572961010868
Epoch 14 loss: 0.00119125302363601
Epoch 15 loss: 0.0012852071473995845
Epoch 16 loss: 0.0012386964064919287
Epoch 17 loss: 0.0011681979910160105
Epoch 18 loss: 0.001209260533667273
Epoch 19 loss: 0.0011682885362663204
Epoch 20 loss: 0.0012061121977037853
Epoch 21 loss: 0.0011538946871749228
Epoch 22 loss: 0.0011809114237419432
Epoch 23 loss: 0.0011563727554554741
Epoch 24 loss: 0.0011615114053711295
Epoch 25 loss: 0.0011634438479733136
Epoch 26 loss: 0.001

[I 2024-12-30 11:37:28,409] Trial 4 finished with value: 0.0011611692090001372 and parameters: {'lr': 0.005244172927782109, 'decay': 0.7763730170145892}. Best is trial 4 with value: 0.0011611692090001372.


Epoch 29 loss: 0.0011611692090001372
Epoch 0 loss: 0.012569226841959689
Epoch 1 loss: 0.004776301586793529
Epoch 2 loss: 0.0031629976712995106
Epoch 3 loss: 0.002359948871243331
Epoch 4 loss: 0.0020237166124085584
Epoch 5 loss: 0.0017405294994306234
Epoch 6 loss: 0.001632227575302952
Epoch 7 loss: 0.0014738697532771362
Epoch 8 loss: 0.00140201511223697
Epoch 9 loss: 0.001395413716737595
Epoch 10 loss: 0.0013608189765363932
Epoch 11 loss: 0.0012903701410525376
Epoch 12 loss: 0.001265456007483105
Epoch 13 loss: 0.0012655214199589358
Epoch 14 loss: 0.0012456391575849718
Epoch 15 loss: 0.001270177825871441
Epoch 16 loss: 0.0012067029278518425
Epoch 17 loss: 0.0012134651058456963
Epoch 18 loss: 0.0012070256052538753
Epoch 19 loss: 0.0012065911691428886
Epoch 20 loss: 0.0012052830231065552
Epoch 21 loss: 0.001196594300886823
Epoch 22 loss: 0.0011970124647228254
Epoch 23 loss: 0.0012215827333016528
Epoch 24 loss: 0.001197314420197573
Epoch 25 loss: 0.0011806176690798667
Epoch 26 loss: 0.00117

[I 2024-12-30 12:18:37,006] Trial 5 finished with value: 0.0011725822696462274 and parameters: {'lr': 0.001101049695775549, 'decay': 0.6771688198855408}. Best is trial 4 with value: 0.0011611692090001372.


Epoch 29 loss: 0.0011725822696462274
Epoch 0 loss: 0.0045801033783290125
Epoch 1 loss: 0.0018830505303210681
Epoch 2 loss: 0.001472932260690464
Epoch 3 loss: 0.0013413692488231594
Epoch 4 loss: 0.0013296226065398918
Epoch 5 loss: 0.0013706048743592368
Epoch 6 loss: 0.001258664428152972
Epoch 7 loss: 0.0012234435028706987
Epoch 8 loss: 0.001536592595382697
Epoch 9 loss: 0.001192740294047528
Epoch 10 loss: 0.0012011510035437015
Epoch 11 loss: 0.0012051386422374183
Epoch 12 loss: 0.0011795333250322277
Epoch 13 loss: 0.0013317544039131866
Epoch 14 loss: 0.0011477803992521432
Epoch 15 loss: 0.0011524789863162571
Epoch 16 loss: 0.001156322670997017
Epoch 17 loss: 0.001172014059395426
Epoch 18 loss: 0.0012235795147717
Epoch 19 loss: 0.001153442374844518
Epoch 20 loss: 0.0011591700345484747
Epoch 21 loss: 0.0011715947184711695
Epoch 22 loss: 0.001152068467086388
Epoch 23 loss: 0.0011788118734127944
Epoch 24 loss: 0.0011584747763764528
Epoch 25 loss: 0.0011503631507770882
Epoch 26 loss: 0.00116

[I 2024-12-30 12:59:52,102] Trial 6 finished with value: 0.0011773365032341746 and parameters: {'lr': 0.0031859780971462727, 'decay': 0.5757470121809716}. Best is trial 4 with value: 0.0011611692090001372.


Epoch 29 loss: 0.0011773365032341746
Epoch 0 loss: 0.020174656477239396
Epoch 1 loss: 0.009477922692894936
Epoch 2 loss: 0.006409192871716287
Epoch 3 loss: 0.005192306979248921
Epoch 4 loss: 0.005021288318352567
Epoch 5 loss: 0.0047052098541624015
Epoch 6 loss: 0.004313412277648847
Epoch 7 loss: 0.0040937479999330305
Epoch 8 loss: 0.003958403805477751
Epoch 9 loss: 0.0038201044468830028
Epoch 10 loss: 0.0037131483097457224
Epoch 11 loss: 0.0035327374417748717
Epoch 12 loss: 0.0035114098185052476
Epoch 13 loss: 0.0034429612052109507
Epoch 14 loss: 0.003356244104603926
Epoch 15 loss: 0.003303590980875823
Epoch 16 loss: 0.003262213250208232
Epoch 17 loss: 0.0032344156514025396
Epoch 18 loss: 0.0032018463179055187
Epoch 19 loss: 0.003172205062583089
Epoch 20 loss: 0.003148844413873222
Epoch 21 loss: 0.0031371175621946654
Epoch 22 loss: 0.003117726262037953
Epoch 23 loss: 0.0030894804642432267
Epoch 24 loss: 0.0030890077662964663
Epoch 25 loss: 0.003078323426759905
Epoch 26 loss: 0.00307116

[I 2024-12-30 13:41:24,014] Trial 7 finished with value: 0.003053103056218889 and parameters: {'lr': 0.0005691122332540043, 'decay': 0.4547819980244929}. Best is trial 4 with value: 0.0011611692090001372.


Epoch 29 loss: 0.003053103056218889
Epoch 0 loss: 0.012631095324953398
Epoch 1 loss: 0.00647738017141819
Epoch 2 loss: 0.004115346301760938
Epoch 3 loss: 0.0032472173786825603
Epoch 4 loss: 0.003110942534274525
Epoch 5 loss: 0.002726926908103956
Epoch 6 loss: 0.002297153334236807
Epoch 7 loss: 0.0020429656013018554
Epoch 8 loss: 0.0018450046837743786
Epoch 9 loss: 0.0016636449662554595
Epoch 10 loss: 0.0015061203157529235
Epoch 11 loss: 0.0014313582279202011
Epoch 12 loss: 0.0013803258900427157
Epoch 13 loss: 0.0013411035761237144
Epoch 14 loss: 0.0013186566454047959
Epoch 15 loss: 0.0012991020533566673
Epoch 16 loss: 0.0012836767645138833
Epoch 17 loss: 0.0012739461122287645
Epoch 18 loss: 0.0012659617415111926
Epoch 19 loss: 0.0012585670143986742
Epoch 20 loss: 0.0012536135295199023
Epoch 21 loss: 0.0012500029341835114
Epoch 22 loss: 0.0012511391865296497
Epoch 23 loss: 0.0012407718329793876
Epoch 24 loss: 0.0013750515257318814
Epoch 25 loss: 0.0012558915833425191
Epoch 26 loss: 0.00

[I 2024-12-30 14:22:54,155] Trial 8 finished with value: 0.0012274232511926028 and parameters: {'lr': 0.0007409412215398597, 'decay': 0.6582337496026821}. Best is trial 4 with value: 0.0011611692090001372.


Epoch 29 loss: 0.0012274232511926028
Epoch 0 loss: 0.002181144534713692
Epoch 1 loss: 0.0014833345259022382
Epoch 2 loss: 0.0014102250311730637
Epoch 3 loss: 0.0012770439110075433
Epoch 4 loss: 0.0011846806802269486
Epoch 5 loss: 0.001106404712320202
Epoch 6 loss: 0.0013088308041915298
Epoch 7 loss: 0.0011372130959191257
Epoch 8 loss: 0.0014295492761044039
Epoch 9 loss: 0.0011541202483284804
Epoch 10 loss: 0.0011516613497709234
Epoch 11 loss: 0.001129748791249262
Epoch 12 loss: 0.0013300058069742387
Epoch 13 loss: 0.001127000445396536
Epoch 14 loss: 0.001168181563520597
Epoch 15 loss: 0.0011331621660954421
Epoch 16 loss: 0.0011574560129601094
Epoch 17 loss: 0.0011477165389806032
Epoch 18 loss: 0.0011337480326700541
Epoch 19 loss: 0.0012315994956427151
Epoch 20 loss: 0.0010832372742394607
Epoch 21 loss: 0.0011440766716582908
Epoch 22 loss: 0.0011483490492941604
Epoch 23 loss: 0.00117105211959117
Epoch 24 loss: 0.0011433824482891294
Epoch 25 loss: 0.00114957919706487
Epoch 26 loss: 0.001

[I 2024-12-30 15:04:27,416] Trial 9 finished with value: 0.0011574472818109724 and parameters: {'lr': 0.005551990087504371, 'decay': 0.7834123529246597}. Best is trial 9 with value: 0.0011574472818109724.


Epoch 29 loss: 0.0011574472818109724
Epoch 0 loss: 0.0636355761024687
Epoch 1 loss: 0.031351566314697266
Epoch 2 loss: 0.018599496119552188
Epoch 3 loss: 0.012339744199481275
Epoch 4 loss: 0.01159827493959003
Epoch 5 loss: 0.010680551123287942
Epoch 6 loss: 0.009842776382962862
Epoch 7 loss: 0.009148856004079184
Epoch 8 loss: 0.008994325271083249
Epoch 9 loss: 0.008789121587243345
Epoch 10 loss: 0.008652645266718335
Epoch 11 loss: 0.008477320584158102
Epoch 12 loss: 0.00844869731614987
Epoch 13 loss: 0.008383675685359372
Epoch 14 loss: 0.008340236006511582
Epoch 15 loss: 0.008280024863779545
Epoch 16 loss: 0.008263207040727139
Epoch 17 loss: 0.008245172703431712
Epoch 18 loss: 0.008224173759420713
Epoch 19 loss: 0.008207325409683917
Epoch 20 loss: 0.008201555969814459
Epoch 21 loss: 0.008193967760437064
Epoch 22 loss: 0.008192367955214448
Epoch 23 loss: 0.008184544742107391
Epoch 24 loss: 0.008181578376226954


[I 2024-12-30 15:40:27,851] Trial 10 pruned. 


Epoch 25 loss: 0.00817828563352426
Epoch 0 loss: 0.002548650916044911
Epoch 1 loss: 0.0015204038289893004
Epoch 2 loss: 0.0014536830130964518
Epoch 3 loss: 0.0013247893140133885
Epoch 4 loss: 0.0013168415235769418
Epoch 5 loss: 0.0012680907578517993
Epoch 6 loss: 0.0012674243060044115
Epoch 7 loss: 0.0012376239368071158
Epoch 8 loss: 0.0012534467063637243
Epoch 9 loss: 0.0011989697813987732
Epoch 10 loss: 0.0012212698994618324
Epoch 11 loss: 0.0012115536438715127
Epoch 12 loss: 0.001230463191556434
Epoch 13 loss: 0.0012178942043748167
Epoch 14 loss: 0.0011936822750916083
Epoch 15 loss: 0.00121944321371201
Epoch 16 loss: 0.001185157659670545
Epoch 17 loss: 0.0012083734489149517
Epoch 18 loss: 0.0012094286761970983
Epoch 19 loss: 0.0012080181881578432
Epoch 20 loss: 0.0012048023701128033
Epoch 21 loss: 0.0012054163833252257
Epoch 22 loss: 0.0012034817158968912
Epoch 23 loss: 0.0012082884269249109
Epoch 24 loss: 0.001198284172763427
Epoch 25 loss: 0.0012029361419586672
Epoch 26 loss: 0.00

[I 2024-12-30 16:10:54,255] Trial 11 finished with value: 0.0012119046101967494 and parameters: {'lr': 0.009468346119167576, 'decay': 0.8065984387396006}. Best is trial 9 with value: 0.0011574472818109724.


Epoch 29 loss: 0.0012119046101967494
Epoch 0 loss: 0.0031825969668312203
Epoch 1 loss: 0.0016667084483843711
Epoch 2 loss: 0.0014842938010891278
Epoch 3 loss: 0.001406044893277188
Epoch 4 loss: 0.0013566331989649269
Epoch 5 loss: 0.001356084507683085
Epoch 6 loss: 0.0012918708121611013
Epoch 7 loss: 0.0012603185734608108
Epoch 8 loss: 0.0012524734095980723
Epoch 9 loss: 0.001263446834248801
Epoch 10 loss: 0.0012447252714385588
Epoch 11 loss: 0.0012429477517596548
Epoch 12 loss: 0.0012111689688430892
Epoch 13 loss: 0.001242546713910997
Epoch 14 loss: 0.001212845297737254
Epoch 15 loss: 0.0012122316725759043
Epoch 16 loss: 0.0012300401641469863
Epoch 17 loss: 0.0012246422573096221
Epoch 18 loss: 0.0012191887944936752
Epoch 19 loss: 0.001214551027967698
Epoch 20 loss: 0.0012279079399175113
Epoch 21 loss: 0.0012170225123150481
Epoch 22 loss: 0.0012200281360290116
Epoch 23 loss: 0.0012156084930110308
Epoch 24 loss: 0.0012264238773948615
Epoch 25 loss: 0.0012220314755621883
Epoch 26 loss: 0.

[I 2024-12-30 16:40:21,801] Trial 12 finished with value: 0.0012202517051870625 and parameters: {'lr': 0.009536306028847988, 'decay': 0.7689268167827322}. Best is trial 9 with value: 0.0011574472818109724.


Epoch 29 loss: 0.0012202517051870625
Epoch 0 loss: 0.004283553817205959
Epoch 1 loss: 0.0020860470831394196
Epoch 2 loss: 0.0016491394360653227
Epoch 3 loss: 0.00145602409934832
Epoch 4 loss: 0.0013419203460216522
Epoch 5 loss: 0.0013056227083628376
Epoch 6 loss: 0.0015295806547833814
Epoch 7 loss: 0.0012149725161078903
Epoch 8 loss: 0.0012130258449663718
Epoch 9 loss: 0.001493973334112929
Epoch 10 loss: 0.0011917426406095426
Epoch 11 loss: 0.0011895516137075094
Epoch 12 loss: 0.0011746768141165376
Epoch 13 loss: 0.0011893852950177258
Epoch 14 loss: 0.00119099041654004
Epoch 15 loss: 0.0011821709987190035
Epoch 16 loss: 0.0011950746411457658
Epoch 17 loss: 0.0011643320839438173
Epoch 18 loss: 0.0011648871263282166
Epoch 19 loss: 0.0011714657044245137
Epoch 20 loss: 0.0011688937407193913
Epoch 21 loss: 0.0011638964836796124
Epoch 22 loss: 0.0011962611719758974
Epoch 23 loss: 0.0011593359134470422
Epoch 24 loss: 0.0011825534877263838
Epoch 25 loss: 0.001156240649935272
Epoch 26 loss: 0.0

[I 2024-12-30 17:09:45,824] Trial 13 finished with value: 0.0011657905350956651 and parameters: {'lr': 0.0026096063451923196, 'decay': 0.5497714600593803}. Best is trial 9 with value: 0.0011574472818109724.


Epoch 29 loss: 0.0011657905350956651
Epoch 0 loss: 0.004487482365220785
Epoch 1 loss: 0.001796971993624336
Epoch 2 loss: 0.0016902048502945239
Epoch 3 loss: 0.00145911466744211
Epoch 4 loss: 0.001348349201079044
Epoch 5 loss: 0.0013644146965816617
Epoch 6 loss: 0.0014046015373120706
Epoch 7 loss: 0.0012314291282867391
Epoch 8 loss: 0.0011510178188069
Epoch 9 loss: 0.0012408006134339506
Epoch 10 loss: 0.0011759605258703232
Epoch 11 loss: 0.0011675529337177675
Epoch 12 loss: 0.001213257330366307
Epoch 13 loss: 0.0011952105236964093
Epoch 14 loss: 0.0011354746012431052
Epoch 15 loss: 0.0011834693788033393
Epoch 16 loss: 0.0011553307219098012
Epoch 17 loss: 0.001131662762620383
Epoch 18 loss: 0.0013113769236952066
Epoch 19 loss: 0.0010872208772020207
Epoch 20 loss: 0.0011378138895250028
Epoch 21 loss: 0.0015075804096543128
Epoch 22 loss: 0.0012620626690073146
Epoch 23 loss: 0.0011597041221749452
Epoch 24 loss: 0.0011706815437517231
Epoch 25 loss: 0.0011806942315565215
Epoch 26 loss: 0.0011

[I 2024-12-30 17:39:10,991] Trial 14 finished with value: 0.0011473140912130475 and parameters: {'lr': 0.0031310830139417182, 'decay': 0.8847056169377736}. Best is trial 14 with value: 0.0011473140912130475.


Epoch 29 loss: 0.0011473140912130475
Best hyperparameters: {'lr': 0.0031310830139417182, 'decay': 0.8847056169377736}


In [3]:
import optuna
import torch
from torch.utils.data import DataLoader
from torch.optim import Adam
from monai.losses import DiceLoss
from monai.networks.nets import UNet
from optuna.pruners import MedianPruner

vis = visual.loss_precision_recall(20, labels, 2.0)
vis.start()
vis.new_trial()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# ------------------------------ #
#        HYPERPARAMETERS         #
# ------------------------------ #
lr = 5.0e-4
decay = 0.9
# dropout = trial.suggest_float("dropout", 0.25, 0.5)
dropout = 0.3
regularization_strength = 1e-3
# alpha = trial.suggest_float("alpha", 0.25, 1.0)
# theta = trial.suggest_float("theta", 0.1, 0.9)
theta = 0.5
# gamma = trial.suggest_float("gamma", 2.0, 5.0)
gamma = 4.0
# Model initialization
model = UNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=7,
    channels=(64, 128, 256, 512),
    strides=(2, 2, 2),
    num_res_units=2,
    dropout=dropout,
).to(device)

num_epochs = 15
batch_size = 16

# ------------------------------ #
#        TRAINING METHODS        #
# ------------------------------ #
weights = torch.tensor([0.0434743, 1.16546, 1.1661, 1.16513, 1.14281, 1.15554, 1.16149]).to(device)  # Example weights for classes
# weights = torch.tensor([1.0,1.0,1.0,1.0,1.0,1.0,1.0])
dice_loss = DiceLoss(to_onehot_y=False, softmax=True, weight=weights).to(device)
focal_loss = FocalLoss(to_onehot_y=False, use_softmax=True, weight=weights, gamma=gamma ).to(device)
# mse_loss = nn.MSELoss(reduction='none')

optimizer = Adam(model.parameters(), lr=lr)

scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=4, gamma=decay)

train_loader = DataLoader(train_dataset, batch_size=batch_size, collate_fn=dataset.collate_fn, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=batch_size, collate_fn=dataset.collate_fn, shuffle=False, num_workers=4)

def add_regularization_loss(model, regularization_strength):
    reg_loss = 0
    for param in model.parameters():
        reg_loss += torch.sum(param ** 2)
    return regularization_strength * reg_loss

# ------------------------------ #
#             TRAIN              #
# ------------------------------ #
for epoch in range(num_epochs):
    model.train()
    for batch in train_loader:
        input, target = batch['src'].to(device), batch['tgt'].to(device)
        optimizer.zero_grad()
        output = model(input)
        loss = (theta) * dice_loss(output, target) + (1 - theta) * focal_loss(output, target)
        reg_loss = add_regularization_loss(model, regularization_strength)
        loss += reg_loss
        loss.backward()
        optimizer.step()

    scheduler.step()
        
    # ---------- #
    # VALIDATION #
    # ---------- #
    model.eval()
    val_loss = 0
    precision = torch.zeros((7))
    recall = torch.zeros((7))
    with torch.no_grad():
        for batch in val_loader:
            input, target = batch['src'].float().to(device), batch['tgt'].long().to(device)
            output = model(input)
            loss = (theta) * dice_loss(output, target) + (1 - theta) * focal_loss(output, target)
            reg_loss = add_regularization_loss(model, regularization_strength)
            loss += reg_loss
            val_loss += loss.item()
            p, r = metrics.continuous_precision_recall(target.to('cpu'), torch.softmax(output.to('cpu'), dim=1))
            precision += p
            recall += r
    val_loss /= len(val_loader)
    pr = torch.stack([precision, recall], dim=0)
    vis.report(val_loss, pr)
            
    print(f"Epoch {epoch} loss: {val_loss}")
    

Epoch 0 loss: 0.6432971689436171
Epoch 1 loss: 0.5598176187939115
Epoch 2 loss: 0.5357193417019315
Epoch 3 loss: 0.5257072117593553
Epoch 4 loss: 0.525143735938602


KeyboardInterrupt: 

In [4]:
torch.save(model.state_dict(), "HeatNet-1-0.pth")


In [1]:
# Inference
import sys
sys.path.insert(1, 'H:/Projects/Kaggle/CZII-CryoET-Object-Identification/preprocessing')
import load
import augment
import os
import torch
from monai.networks.nets import UNet
import numpy as np
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
run = 'TS_6_4'
root = load.get_root()
picks = load.get_picks_dict(root)
vol, coords, scales = load.get_run_volume_picks(root, run=run, level=0)
for key in coords.keys():
    coords[key] = np.array(coords[key], dtype=np.int16)
coord_list = []
for key in coords.keys():
    coord_list.append(coords[key])
radii = [ 6,
          6,
          9,
          15,
          13,
          14 ]
params = augment.aug_params
params["final_size"] = (104,104,104)
params["flip_prob"] = 0.0
params["patch_size"] = (104,104,104)
params["rot_prob"] = 0.0

In [5]:
mask = load.create_exponential_heatmap_gpu(6, vol.shape, coord_list, radii).cpu().numpy()

In [35]:
sample = augment.random_augmentation_gpu(vol, 
                                mask,
                                num_samples=1, 
                                aug_params=params
                                )

src = sample[0]["source"].unsqueeze(0).to(device)
tgt = sample[0]["target"].unsqueeze(0).to(device)

model = UNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=7,
    channels=(64, 128, 256, 512),
    strides=(2, 2, 2),
    num_res_units=2,
    dropout=0.1,
).to(device)
model.load_state_dict(torch.load("HeatNet-1-0.pth"))

model.eval()
pred = torch.softmax(model(src), dim=1).to('cpu')
pred = pred.squeeze().to('cpu').detach().numpy()
src = src.to('cpu').squeeze()
tgt = tgt.squeeze(0).to('cpu').detach().numpy()
print(f"src shape {src.shape}")
print(f"tgt shape {tgt.shape}")
print(f"pred shape {pred.shape}")


src shape torch.Size([104, 104, 104])
tgt shape (7, 104, 104, 104)
pred shape (7, 104, 104, 104)


In [36]:
import matplotlib.pyplot as plt
import numpy as np
from ipywidgets import interact

def plot_cross_section(i):
    vol1 = 1.0 - src
    vol1 = np.zeros(pred[0].shape)
    vol2 = pred[1]
    
    plt.figure(figsize=(15, 5))
    alpha = 0.3

    # Slice at x-coordinate
    plt.subplot(131)
    plt.imshow(vol1[i, :, :], cmap="viridis", alpha=alpha)
    plt.imshow(vol2[i, :, :], cmap="Reds", alpha=alpha)  # Overlay mask with transparency
    plt.title(f'Slice at x={i}')

    # Slice at y-coordinate
    plt.subplot(132)
    plt.imshow(vol1[:, i, :], cmap="viridis", alpha=alpha)
    plt.imshow(vol2[:, i, :], cmap="Blues", alpha=alpha)
    plt.title(f'Slice at y={i}')

    # Slice at z-coordinate
    plt.subplot(133)
    plt.imshow(vol1[:, :, i], cmap="viridis", alpha=alpha)
    plt.imshow(vol2[:, :, i], cmap="Reds", alpha=alpha)
    plt.title(f'Slice at z={i}')

    plt.show()

# Interactive Slider for scrolling through slices
interact(plot_cross_section, i=(0, src.shape[0] - 1))

interactive(children=(IntSlider(value=51, description='i', max=103), Output()), _dom_classes=('widget-interact…

<function __main__.plot_cross_section(i)>