In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split, SubsetRandomSampler
import h5py
from tqdm import tqdm
import random

# -------------------------
# Squeeze-and-Excitation Block
# -------------------------
class SEBlock(nn.Module):
    def __init__(self, channels, reduction=16):
        super().__init__()
        self.fc = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(channels, channels // reduction, 1),
            nn.GELU(),
            nn.Conv2d(channels // reduction, channels, 1),
            nn.Sigmoid()
        )
    def forward(self, x):
        return x * self.fc(x)

# -------------------------
# Attention Gate for Skip Connections
# -------------------------
class AttentionGate(nn.Module):
    def __init__(self, F_g, F_l, F_int):
        super().__init__()
        self.W_g = nn.Conv2d(F_g, F_int, kernel_size=1, bias=False)
        self.W_x = nn.Conv2d(F_l, F_int, kernel_size=1, bias=False)
        self.psi = nn.Sequential(
            nn.Conv2d(F_int, 1, kernel_size=1, bias=True),
            nn.Sigmoid()
        )
        self.relu = nn.ReLU(inplace=True)
    def forward(self, g, x):
        # g: gating signal, x: skip connection
        g1 = self.W_g(g)
        x1 = self.W_x(x)
        psi = self.relu(g1 + x1)
        psi = self.psi(psi)
        return x * psi

# -------------------------
# Depthwise Separable Conv (for ASPP)
# -------------------------
class SepConv(nn.Module):
    def __init__(self, in_ch, out_ch, kernel_size, padding, dilation=1):
        super().__init__()
        self.depthwise = nn.Conv2d(in_ch, in_ch, kernel_size, padding=padding,
                                   dilation=dilation, groups=in_ch, bias=False)
        self.pointwise = nn.Conv2d(in_ch, out_ch, 1, bias=False)
        self.norm = nn.GroupNorm(8, out_ch)
        self.act = nn.ReLU()
    def forward(self, x):
        x = self.depthwise(x)
        x = self.pointwise(x)
        return self.act(self.norm(x))

# -------------------------
# Multi-Scale Fusion
# -------------------------
class MultiScaleFusion(nn.Module):
    def __init__(self, in_ch, out_ch, scales=(1.0,0.5,0.25)):
        super().__init__()
        self.scales = scales
        self.convs = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(in_ch, out_ch, 3, padding=1, bias=False),
                nn.GroupNorm(8, out_ch), nn.ReLU(),
                SEBlock(out_ch)
            ) for _ in scales
        ])
        self.merge = nn.Sequential(
            nn.Conv2d(len(scales)*out_ch, out_ch, 1, bias=False),
            nn.GroupNorm(8, out_ch), nn.ReLU(),
            SEBlock(out_ch)
        )
    def forward(self, x):
        bs,c,h,w = x.shape
        outs = []
        for s,conv in zip(self.scales,self.convs):
            xi = F.interpolate(x, scale_factor=s, mode='bilinear', align_corners=False) if s!=1.0 else x
            if s!=1.0:
                xi = F.interpolate(xi, size=(h,w), mode='bilinear', align_corners=False)
            outs.append(conv(xi))
        return self.merge(torch.cat(outs, dim=1))

# -------------------------
# ASPP Module with Mixed Kernels
# -------------------------
class ASPP(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        dilations = [1,6,12,18]
        kernels   = [3,3,5,7]
        self.branches = nn.ModuleList()
        for d,k in zip(dilations,kernels):
            pad = (k//2)*d
            self.branches.append(SepConv(in_ch, out_ch, k, pad, d))
        self.merge = nn.Sequential(
            nn.Conv2d(len(dilations)*out_ch, out_ch, 1, bias=False),
            nn.GroupNorm(8, out_ch), nn.ReLU(),
            SEBlock(out_ch)
        )
    def forward(self,x):
        return self.merge(torch.cat([b(x) for b in self.branches], dim=1))

# -------------------------
# Residual DoubleConv + Strided Downsampling
# -------------------------
class DoubleConv(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv1 = nn.Conv2d(in_ch, out_ch, 5, padding=2, bias=False)
        self.norm1 = nn.GroupNorm(8, out_ch)
        self.act1  = nn.GELU()
        self.conv2 = nn.Conv2d(out_ch, out_ch, 5, padding=2, bias=False)
        self.norm2 = nn.GroupNorm(8, out_ch)
        self.act2  = nn.ReLU()
        self.res   = nn.Conv2d(in_ch, out_ch, 1, bias=False)
        self.se    = SEBlock(out_ch)
    def forward(self, x):
        residual = self.res(x)
        x = self.act1(self.norm1(self.conv1(x)))
        x = self.norm2(self.conv2(x))
        x = self.act2(x + residual)
        return self.se(x)

class PoolConv(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 5, stride=2, padding=2, bias=False),
            nn.GroupNorm(8, out_ch), nn.ReLU()
        )
    def forward(self, x): return self.net(x)

# -------------------------
# Bottleneck Transformer
# -------------------------
class BottleneckTransformer(nn.Module):
    def __init__(self, dim, heads=8, depth=6, mlp_dim=None):
        super().__init__()
        mlp_dim = mlp_dim or dim*4
        layer = nn.TransformerEncoderLayer(d_model=dim, nhead=heads,
                                           dim_feedforward=mlp_dim,
                                           activation='relu', norm_first=True)
        self.encoder = nn.TransformerEncoder(layer, depth)
        self.norm    = nn.LayerNorm(dim)
    def forward(self,x):
        bs,c,h,w = x.shape
        x = x.flatten(2).permute(2,0,1)
        x = self.encoder(x)
        x = x.permute(1,2,0).view(bs,c,h,w)
        return self.norm(x.permute(0,2,3,1)).permute(0,3,1,2)

# -------------------------
# Encoder/Decoder Blocks
# -------------------------
class EncoderBlock(nn.Module):
    def __init__(self,in_ch,out_ch,dropout=0.0,pool=True):
        super().__init__()
        self.conv = DoubleConv(in_ch, out_ch)
        self.down = PoolConv(out_ch, out_ch) if pool else None
        self.drop = nn.Dropout2d(dropout) if dropout>0 else None
    def forward(self,x):
        x = self.conv(x)
        if self.drop: x = self.drop(x)
        skip = x
        x = self.down(x) if self.down else x
        return x, skip

class DecoderBlock(nn.Module):
    def __init__(self,in_ch,out_ch,dropout=0.0):
        super().__init__()
        self.up   = nn.ConvTranspose2d(in_ch, out_ch, 2, stride=2)
        self.attn = AttentionGate(F_g=out_ch, F_l=out_ch, F_int=out_ch//2)
        self.conv = DoubleConv(out_ch*2, out_ch)
        self.drop = nn.Dropout2d(dropout) if dropout>0 else None
    def forward(self,x,skip):
        x = self.up(x)
        skip = self.attn(g=x, x=skip)
        x = self.conv(torch.cat([x,skip],dim=1))
        return self.drop(x) if self.drop else x

# -------------------------
# Full UNet with Attention and SE
# -------------------------
class UNetEnhanced(nn.Module):
    def __init__(self,in_ch=1,base=64,depth=4,dropout=0.0,ds=True):
        super().__init__()
        self.encs = nn.ModuleList()
        ch = in_ch
        for i in range(depth):
            out = base*(2**i)
            self.encs.append(EncoderBlock(ch, out, dropout, pool=(i<depth-1)))
            ch = out
        self.ms   = MultiScaleFusion(ch, ch)
        self.aspp = ASPP(ch, ch)
        self.trans= BottleneckTransformer(ch)
        self.decs = nn.ModuleList()
        ds_chs   = []
        for i in reversed(range(depth-1)):
            out = base*(2**i)
            self.decs.append(DecoderBlock(ch, out, dropout))
            ds_chs.append(out)
            ch = out
        self.final_seg = nn.Conv2d(ch, 1, kernel_size=1)
        self.ds = ds
        if ds:
            self.ds_heads = nn.ModuleList([
                nn.Conv2d(c, 1, kernel_size=1) for c in ds_chs
            ])
    def forward(self,x):
        skips = []
        for enc in self.encs:
            x, skip = enc(x)
            skips.append(skip)
        x = self.ms(x)
        x = self.aspp(x)
        x = self.trans(x)
        ds_out = []
        for idx, dec in enumerate(self.decs):
            x = dec(x, skips[-2-idx])
            if self.ds:
                ds_out.append(self.ds_heads[idx](x))
        seg_logits = self.final_seg(x)
        if self.ds:
            ds_out = [F.interpolate(o, size=seg_logits.shape[2:],
                         mode='bilinear', align_corners=False)
                      for o in ds_out]
            return seg_logits, ds_out
        return seg_logits, None

class TverskyF2Loss(nn.Module):
    def forward(self, preds, targets):
        TP = (preds*targets).sum()
        FN = ((1-preds)*targets).sum()
        FP = (preds*(1-targets)).sum()
        return (FN + 1e-6) / (2*TP + FN + 1e-6)

class FocalTverskyLoss(nn.Module):
    def __init__(self, alpha=0.9, beta=0.1, gamma=2.0, eps=1e-6):
        super().__init__()
        self.alpha, self.beta, self.gamma, self.eps = alpha, beta, gamma, eps
    def forward(self, preds, targets):
        preds = preds.view(-1)
        targets = targets.view(-1)
        TP = (preds * targets).sum()
        FP = (preds * (1 - targets)).sum()
        FN = ((1 - preds) * targets).sum()
        tversky = (TP + self.eps) / (TP + self.alpha*FN + self.beta*FP + self.eps)
        return torch.pow((1 - tversky), self.gamma)

class ComboLoss(nn.Module):
    def __init__(self, pos_weight=500.0, bce_weight=0.2, dice_weight=0.4, f2_weight=0.4):
        super().__init__()
        self.bce  = nn.BCEWithLogitsLoss(pos_weight=torch.tensor(pos_weight))
        self.ft   = FocalTverskyLoss(alpha=0.9, beta=0.3, gamma=2.0)
        self.tf2 = TverskyF2Loss()
        self.bw, self.dw, self.fw= bce_weight, dice_weight, f2_weight
    def forward(self, logits, targets):
        bce_loss  = self.bce(logits, targets)
        dice_loss = self.ft(torch.sigmoid(logits), targets)
        f2_loss   = self.tf2(torch.sigmoid(logits), targets)
        return self.bw*bce_loss + self.dw*dice_loss + self.fw*f2_loss


class H5Dataset(Dataset):
    def __init__(self,path): self.path=path; self.f=None
    def __len__(self): return h5py.File(self.path,'r')['x'].shape[0]
    def __getitem__(self,idx):
        if self.f is None:
            self.f = h5py.File(self.path,'r', swmr=True)
        x = torch.from_numpy(self.f['x'][idx]).float()
        y = torch.from_numpy(self.f['y'][idx]).float()
        return x, y

In [8]:
if __name__ == '__main__':
    bs, max_batches, epochs = 5, 500, 200
    train_path = 'train.h5'
    val_path   = 'val.h5'

    # ─── datasets ───
    train_ds = H5Dataset(train_path)
    val_ds   = H5Dataset(val_path)

    # ─── train loader ───
    # pick exactly bs*max_batches random samples from train set
    small_idx = torch.randperm(len(train_ds))[:10]
    small_loader = DataLoader(train_ds, batch_size=5,
                          sampler=SubsetRandomSampler(small_idx))

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = UNetEnhanced(in_ch=1, base=32, depth=4, dropout=0.0, ds=True).to(device)
    crit  = ComboLoss(pos_weight=500, bce_weight=0.5, dice_weight=0.5, f2_weight=0.0)
    opt   = torch.optim.Adam(model.parameters(), lr=1e-4)
    sched = torch.optim.lr_scheduler.ConstantLR(opt, factor=1.0, total_iters=epochs)

    ds_weights = [1, 1, 1]

    for epoch in range(1, epochs+1):
        # ——— train ———
        model.train()
        total_loss = 0.0
        for batch_num, (imgs, msk) in enumerate(pbar:=tqdm(small_loader, desc=f"Epoch {epoch}")):
            #imgs[msk==1] = imgs[msk==1] * 3
            imgs = imgs/50
            imgs = imgs.to(device)
            msk  = msk.to(device)

            seg_logits, ds_out = model(imgs)
            loss = crit(seg_logits, msk)

            # deep supervision
            for aux, w in zip(ds_out, ds_weights):
                loss = loss + w * crit(aux, msk)

            opt.zero_grad()
            loss.backward()
            opt.step()

            total_loss += loss.item()
            pbar.set_postfix(avg_train_loss=total_loss/(batch_num+1), batch_loss=loss.item(), lr=sched.get_last_lr()[0])
        sched.step()

        # ——— validate ———
        model.eval()
        with torch.no_grad():
            tp=fp=fn=0
            loss = 0
            for imgs, msk in small_loader:
                #imgs[msk==1] = imgs[msk==1] * 3
                imgs = imgs/50
                imgs = imgs.to(device)
                msk  = msk.to(device)

                seg_logits, ds_out = model(imgs)
                loss += crit(seg_logits, msk)

                # deep supervision
                for aux, w in zip(ds_out, ds_weights):
                    loss = loss + w * crit(aux, msk)
                preds = (torch.sigmoid(seg_logits)>0.5).float().view(-1)
                t     = msk.view(-1)

                tp += (preds * t).sum().item()
                fp += (preds * (1-t)).sum().item()
                fn += ((1-preds)*t).sum().item()

            prec = tp/(tp+fp+1e-8)
            rec  = tp/(tp+fn+1e-8)
            f1   = 2*prec*rec/(prec+rec+1e-8)
            print(f"Val loss: {loss:.4f}  Val F1: {f1:.4f}  (P={prec:.4f}, R={rec:.4f})")


Epoch 1: 100%|██████████| 2/2 [00:00<00:00, 10.68it/s, avg_train_loss=5.53, batch_loss=6.02, lr=0.0001]


Val loss: 10.5824  Val F1: 0.0125  (P=0.0063, R=0.6698)


Epoch 2: 100%|██████████| 2/2 [00:00<00:00, 43.90it/s, avg_train_loss=5.26, batch_loss=4.91, lr=0.0001]


Val loss: 10.0373  Val F1: 0.0122  (P=0.0062, R=0.8862)


Epoch 3: 100%|██████████| 2/2 [00:00<00:00, 45.18it/s, avg_train_loss=4.99, batch_loss=5.26, lr=0.0001]


Val loss: 9.5142  Val F1: 0.0135  (P=0.0068, R=0.9198)


Epoch 4: 100%|██████████| 2/2 [00:00<00:00, 43.72it/s, avg_train_loss=4.71, batch_loss=5.01, lr=0.0001]


Val loss: 9.0855  Val F1: 0.0161  (P=0.0081, R=0.9254)


Epoch 5: 100%|██████████| 2/2 [00:00<00:00, 44.66it/s, avg_train_loss=4.53, batch_loss=4.1, lr=0.0001]


Val loss: 8.6253  Val F1: 0.0183  (P=0.0092, R=0.9459)


Epoch 6: 100%|██████████| 2/2 [00:00<00:00, 44.43it/s, avg_train_loss=4.29, batch_loss=3.88, lr=0.0001]


Val loss: 8.1945  Val F1: 0.0207  (P=0.0105, R=0.9590)


Epoch 7: 100%|██████████| 2/2 [00:00<00:00, 44.27it/s, avg_train_loss=4.09, batch_loss=3.74, lr=0.0001]


Val loss: 7.6854  Val F1: 0.0251  (P=0.0127, R=0.9832)


Epoch 8: 100%|██████████| 2/2 [00:00<00:00, 44.63it/s, avg_train_loss=3.78, batch_loss=3.49, lr=0.0001]


Val loss: 7.2002  Val F1: 0.0245  (P=0.0124, R=0.9851)


Epoch 9: 100%|██████████| 2/2 [00:00<00:00, 44.28it/s, avg_train_loss=3.58, batch_loss=3.74, lr=0.0001]


Val loss: 6.7906  Val F1: 0.0351  (P=0.0179, R=0.9907)


Epoch 10: 100%|██████████| 2/2 [00:00<00:00, 45.94it/s, avg_train_loss=3.43, batch_loss=3.51, lr=0.0001]


Val loss: 6.5642  Val F1: 0.0422  (P=0.0215, R=0.9925)


Epoch 11: 100%|██████████| 2/2 [00:00<00:00, 38.42it/s, avg_train_loss=3.2, batch_loss=3.01, lr=0.0001]


Val loss: 6.7223  Val F1: 0.0582  (P=0.0300, R=0.9963)


Epoch 12: 100%|██████████| 2/2 [00:00<00:00, 38.36it/s, avg_train_loss=3.26, batch_loss=2.98, lr=0.0001]


Val loss: 6.0594  Val F1: 0.0771  (P=0.0401, R=0.9981)


Epoch 13: 100%|██████████| 2/2 [00:00<00:00, 44.72it/s, avg_train_loss=3.05, batch_loss=3.09, lr=0.0001]


Val loss: 6.0841  Val F1: 0.0853  (P=0.0446, R=0.9963)


Epoch 14: 100%|██████████| 2/2 [00:00<00:00, 42.90it/s, avg_train_loss=3.01, batch_loss=2.97, lr=0.0001]


Val loss: 5.7926  Val F1: 0.1507  (P=0.0815, R=0.9981)


Epoch 15: 100%|██████████| 2/2 [00:00<00:00, 31.37it/s, avg_train_loss=2.92, batch_loss=2.94, lr=0.0001]


Val loss: 5.7393  Val F1: 0.2114  (P=0.1182, R=0.9981)


Epoch 16: 100%|██████████| 2/2 [00:00<00:00, 40.02it/s, avg_train_loss=2.84, batch_loss=2.9, lr=0.0001]


Val loss: 5.5594  Val F1: 0.2061  (P=0.1149, R=1.0000)


Epoch 17: 100%|██████████| 2/2 [00:00<00:00, 40.29it/s, avg_train_loss=2.78, batch_loss=2.79, lr=0.0001]


Val loss: 5.5179  Val F1: 0.1969  (P=0.1092, R=1.0000)


Epoch 18: 100%|██████████| 2/2 [00:00<00:00, 40.19it/s, avg_train_loss=2.75, batch_loss=2.77, lr=0.0001]


Val loss: 5.4171  Val F1: 0.2365  (P=0.1341, R=1.0000)


Epoch 19: 100%|██████████| 2/2 [00:00<00:00, 30.13it/s, avg_train_loss=2.7, batch_loss=2.68, lr=0.0001]


Val loss: 5.3660  Val F1: 0.2806  (P=0.1632, R=1.0000)


Epoch 20: 100%|██████████| 2/2 [00:00<00:00, 33.12it/s, avg_train_loss=2.68, batch_loss=2.64, lr=0.0001]


Val loss: 5.2905  Val F1: 0.2803  (P=0.1630, R=1.0000)


Epoch 21: 100%|██████████| 2/2 [00:00<00:00, 38.62it/s, avg_train_loss=2.64, batch_loss=2.62, lr=0.0001]


Val loss: 5.2428  Val F1: 0.2733  (P=0.1583, R=1.0000)


Epoch 22: 100%|██████████| 2/2 [00:00<00:00, 31.38it/s, avg_train_loss=2.62, batch_loss=2.63, lr=0.0001]


Val loss: 5.2129  Val F1: 0.2793  (P=0.1623, R=1.0000)


Epoch 23: 100%|██████████| 2/2 [00:00<00:00, 30.11it/s, avg_train_loss=2.6, batch_loss=2.6, lr=0.0001]


Val loss: 5.1559  Val F1: 0.3040  (P=0.1793, R=1.0000)


Epoch 24: 100%|██████████| 2/2 [00:00<00:00, 31.35it/s, avg_train_loss=2.58, batch_loss=2.57, lr=0.0001]


Val loss: 5.1472  Val F1: 0.3291  (P=0.1970, R=1.0000)


Epoch 25: 100%|██████████| 2/2 [00:00<00:00, 37.58it/s, avg_train_loss=2.56, batch_loss=2.56, lr=0.0001]


Val loss: 5.0873  Val F1: 0.3255  (P=0.1944, R=1.0000)


Epoch 26: 100%|██████████| 2/2 [00:00<00:00, 37.44it/s, avg_train_loss=2.54, batch_loss=2.54, lr=0.0001]


Val loss: 5.0719  Val F1: 0.3212  (P=0.1913, R=1.0000)


Epoch 27: 100%|██████████| 2/2 [00:00<00:00, 36.58it/s, avg_train_loss=2.54, batch_loss=2.53, lr=0.0001]


Val loss: 5.0320  Val F1: 0.3355  (P=0.2016, R=1.0000)


Epoch 28: 100%|██████████| 2/2 [00:00<00:00, 37.58it/s, avg_train_loss=2.52, batch_loss=2.51, lr=0.0001]


Val loss: 5.0060  Val F1: 0.3455  (P=0.2088, R=1.0000)


Epoch 29: 100%|██████████| 2/2 [00:00<00:00, 34.46it/s, avg_train_loss=2.52, batch_loss=2.5, lr=0.0001]


Val loss: 4.9881  Val F1: 0.3433  (P=0.2072, R=1.0000)


Epoch 30: 100%|██████████| 2/2 [00:00<00:00, 37.91it/s, avg_train_loss=2.49, batch_loss=2.48, lr=0.0001]


Val loss: 4.9482  Val F1: 0.3519  (P=0.2135, R=1.0000)


Epoch 31: 100%|██████████| 2/2 [00:00<00:00, 34.83it/s, avg_train_loss=2.47, batch_loss=2.47, lr=0.0001]


Val loss: 4.9206  Val F1: 0.3612  (P=0.2204, R=1.0000)


Epoch 32: 100%|██████████| 2/2 [00:00<00:00, 32.63it/s, avg_train_loss=2.46, batch_loss=2.44, lr=0.0001]


Val loss: 4.9225  Val F1: 0.3614  (P=0.2206, R=1.0000)


Epoch 33: 100%|██████████| 2/2 [00:00<00:00, 32.23it/s, avg_train_loss=2.44, batch_loss=2.44, lr=0.0001]


Val loss: 4.8666  Val F1: 0.3654  (P=0.2235, R=1.0000)


Epoch 34: 100%|██████████| 2/2 [00:00<00:00, 37.67it/s, avg_train_loss=2.43, batch_loss=2.43, lr=0.0001]


Val loss: 4.8865  Val F1: 0.3747  (P=0.2305, R=1.0000)


Epoch 35: 100%|██████████| 2/2 [00:00<00:00, 31.50it/s, avg_train_loss=2.42, batch_loss=2.44, lr=0.0001]


Val loss: 4.8105  Val F1: 0.3746  (P=0.2304, R=1.0000)


Epoch 36: 100%|██████████| 2/2 [00:00<00:00, 32.24it/s, avg_train_loss=2.4, batch_loss=2.38, lr=0.0001]


Val loss: 4.7947  Val F1: 0.3860  (P=0.2392, R=1.0000)


Epoch 37: 100%|██████████| 2/2 [00:00<00:00, 37.83it/s, avg_train_loss=2.39, batch_loss=2.41, lr=0.0001]


Val loss: 4.7526  Val F1: 0.3860  (P=0.2392, R=1.0000)


Epoch 38: 100%|██████████| 2/2 [00:00<00:00, 30.26it/s, avg_train_loss=2.37, batch_loss=2.38, lr=0.0001]


Val loss: 4.7084  Val F1: 0.3910  (P=0.2430, R=1.0000)


Epoch 39: 100%|██████████| 2/2 [00:00<00:00, 30.87it/s, avg_train_loss=2.35, batch_loss=2.36, lr=0.0001]


Val loss: 4.6744  Val F1: 0.3938  (P=0.2452, R=1.0000)


Epoch 40: 100%|██████████| 2/2 [00:00<00:00, 37.49it/s, avg_train_loss=2.35, batch_loss=2.4, lr=0.0001]


Val loss: 4.6493  Val F1: 0.3931  (P=0.2446, R=1.0000)


Epoch 41: 100%|██████████| 2/2 [00:00<00:00, 30.93it/s, avg_train_loss=2.32, batch_loss=2.29, lr=0.0001]


Val loss: 4.6461  Val F1: 0.4158  (P=0.2625, R=1.0000)


Epoch 42: 100%|██████████| 2/2 [00:00<00:00, 31.57it/s, avg_train_loss=2.3, batch_loss=2.31, lr=0.0001]


Val loss: 4.5585  Val F1: 0.3917  (P=0.2435, R=1.0000)


Epoch 43: 100%|██████████| 2/2 [00:00<00:00, 37.21it/s, avg_train_loss=2.29, batch_loss=2.3, lr=0.0001]


Val loss: 4.5749  Val F1: 0.3757  (P=0.2313, R=1.0000)


Epoch 44: 100%|██████████| 2/2 [00:00<00:00, 34.16it/s, avg_train_loss=2.29, batch_loss=2.29, lr=0.0001]


Val loss: 4.5222  Val F1: 0.4257  (P=0.2704, R=1.0000)


Epoch 45: 100%|██████████| 2/2 [00:00<00:00, 31.94it/s, avg_train_loss=2.26, batch_loss=2.26, lr=0.0001]


Val loss: 4.5410  Val F1: 0.3573  (P=0.2175, R=1.0000)


Epoch 46: 100%|██████████| 2/2 [00:00<00:00, 29.70it/s, avg_train_loss=2.28, batch_loss=2.21, lr=0.0001]


Val loss: 4.4926  Val F1: 0.3852  (P=0.2385, R=1.0000)


Epoch 47: 100%|██████████| 2/2 [00:00<00:00, 31.45it/s, avg_train_loss=2.2, batch_loss=2.18, lr=0.0001]


Val loss: 4.4334  Val F1: 0.4485  (P=0.2891, R=1.0000)


Epoch 48: 100%|██████████| 2/2 [00:00<00:00, 31.19it/s, avg_train_loss=2.19, batch_loss=2.2, lr=0.0001]


Val loss: 4.3765  Val F1: 0.3700  (P=0.2270, R=1.0000)


Epoch 49: 100%|██████████| 2/2 [00:00<00:00, 30.41it/s, avg_train_loss=2.19, batch_loss=2.18, lr=0.0001]


Val loss: 4.3117  Val F1: 0.3764  (P=0.2318, R=1.0000)


Epoch 50: 100%|██████████| 2/2 [00:00<00:00, 36.56it/s, avg_train_loss=2.17, batch_loss=2.19, lr=0.0001]


Val loss: 4.4273  Val F1: 0.3801  (P=0.2347, R=1.0000)


Epoch 51: 100%|██████████| 2/2 [00:00<00:00, 35.24it/s, avg_train_loss=2.11, batch_loss=2.09, lr=0.0001]


Val loss: 4.1767  Val F1: 0.4367  (P=0.2793, R=1.0000)


Epoch 52: 100%|██████████| 2/2 [00:00<00:00, 34.91it/s, avg_train_loss=2.09, batch_loss=2.08, lr=0.0001]


Val loss: 4.2390  Val F1: 0.4061  (P=0.2548, R=1.0000)


Epoch 53: 100%|██████████| 2/2 [00:00<00:00, 36.16it/s, avg_train_loss=2.08, batch_loss=2.04, lr=0.0001]


Val loss: 4.0875  Val F1: 0.4413  (P=0.2831, R=1.0000)


Epoch 54: 100%|██████████| 2/2 [00:00<00:00, 34.23it/s, avg_train_loss=2.04, batch_loss=2.07, lr=0.0001]


Val loss: 4.0728  Val F1: 0.4044  (P=0.2534, R=1.0000)


Epoch 55: 100%|██████████| 2/2 [00:00<00:00, 35.84it/s, avg_train_loss=2.03, batch_loss=2.06, lr=0.0001]


Val loss: 4.0276  Val F1: 0.4271  (P=0.2715, R=1.0000)


Epoch 56: 100%|██████████| 2/2 [00:00<00:00, 31.27it/s, avg_train_loss=1.99, batch_loss=1.99, lr=0.0001]


Val loss: 3.9605  Val F1: 0.4115  (P=0.2591, R=1.0000)


Epoch 57: 100%|██████████| 2/2 [00:00<00:00, 34.68it/s, avg_train_loss=1.99, batch_loss=2, lr=0.0001]


Val loss: 3.9010  Val F1: 0.4402  (P=0.2823, R=1.0000)


Epoch 58: 100%|██████████| 2/2 [00:00<00:00, 34.29it/s, avg_train_loss=2.07, batch_loss=2.11, lr=0.0001]


Val loss: 4.2564  Val F1: 0.3012  (P=0.1773, R=1.0000)


Epoch 59: 100%|██████████| 2/2 [00:00<00:00, 34.90it/s, avg_train_loss=2.22, batch_loss=2.26, lr=0.0001]


Val loss: 4.5171  Val F1: 0.2594  (P=0.1491, R=1.0000)


Epoch 60: 100%|██████████| 2/2 [00:00<00:00, 34.85it/s, avg_train_loss=2.14, batch_loss=2.07, lr=0.0001]


Val loss: 5.8966  Val F1: 0.5546  (P=0.3946, R=0.9328)


Epoch 61: 100%|██████████| 2/2 [00:00<00:00, 35.85it/s, avg_train_loss=2.28, batch_loss=2, lr=0.0001]


Val loss: 4.1865  Val F1: 0.2979  (P=0.1750, R=1.0000)


Epoch 62: 100%|██████████| 2/2 [00:00<00:00, 34.30it/s, avg_train_loss=2.13, batch_loss=2.19, lr=0.0001]


Val loss: 4.2238  Val F1: 0.3283  (P=0.1964, R=1.0000)


Epoch 63: 100%|██████████| 2/2 [00:00<00:00, 35.69it/s, avg_train_loss=2.1, batch_loss=2, lr=0.0001]


Val loss: 4.2204  Val F1: 0.4131  (P=0.2603, R=1.0000)


Epoch 64: 100%|██████████| 2/2 [00:00<00:00, 34.20it/s, avg_train_loss=2.13, batch_loss=2.24, lr=0.0001]


Val loss: 4.2165  Val F1: 0.3664  (P=0.2243, R=1.0000)


Epoch 65: 100%|██████████| 2/2 [00:00<00:00, 36.22it/s, avg_train_loss=2.22, batch_loss=2.41, lr=0.0001]


Val loss: 4.5183  Val F1: 0.2233  (P=0.1257, R=1.0000)


Epoch 66: 100%|██████████| 2/2 [00:00<00:00, 31.88it/s, avg_train_loss=2.24, batch_loss=2.28, lr=0.0001]


Val loss: 4.1469  Val F1: 0.4101  (P=0.2579, R=1.0000)


Epoch 67: 100%|██████████| 2/2 [00:00<00:00, 33.43it/s, avg_train_loss=2.05, batch_loss=2.03, lr=0.0001]


Val loss: 4.1205  Val F1: 0.5146  (P=0.3465, R=1.0000)


Epoch 68: 100%|██████████| 2/2 [00:00<00:00, 30.60it/s, avg_train_loss=2.01, batch_loss=1.96, lr=0.0001]


Val loss: 3.8471  Val F1: 0.4035  (P=0.2527, R=1.0000)


Epoch 69: 100%|██████████| 2/2 [00:00<00:00, 34.96it/s, avg_train_loss=1.92, batch_loss=1.95, lr=0.0001]


Val loss: 3.8554  Val F1: 0.3577  (P=0.2178, R=1.0000)


Epoch 70: 100%|██████████| 2/2 [00:00<00:00, 30.15it/s, avg_train_loss=1.93, batch_loss=1.88, lr=0.0001]


Val loss: 3.8323  Val F1: 0.3589  (P=0.2187, R=1.0000)


Epoch 71: 100%|██████████| 2/2 [00:00<00:00, 31.12it/s, avg_train_loss=1.92, batch_loss=1.91, lr=0.0001]


Val loss: 3.7479  Val F1: 0.3651  (P=0.2233, R=1.0000)


Epoch 72: 100%|██████████| 2/2 [00:00<00:00, 29.93it/s, avg_train_loss=1.92, batch_loss=1.83, lr=0.0001]


Val loss: 3.6842  Val F1: 0.3838  (P=0.2375, R=1.0000)


Epoch 73: 100%|██████████| 2/2 [00:00<00:00, 35.96it/s, avg_train_loss=1.83, batch_loss=1.84, lr=0.0001]


Val loss: 3.6161  Val F1: 0.4219  (P=0.2673, R=1.0000)


Epoch 74: 100%|██████████| 2/2 [00:00<00:00, 31.56it/s, avg_train_loss=1.81, batch_loss=1.8, lr=0.0001]


Val loss: 3.7247  Val F1: 0.4444  (P=0.2857, R=1.0000)


Epoch 75: 100%|██████████| 2/2 [00:00<00:00, 29.96it/s, avg_train_loss=1.79, batch_loss=1.75, lr=0.0001]


Val loss: 3.5272  Val F1: 0.4424  (P=0.2840, R=1.0000)


Epoch 76: 100%|██████████| 2/2 [00:00<00:00, 30.31it/s, avg_train_loss=1.77, batch_loss=1.78, lr=0.0001]


Val loss: 3.4777  Val F1: 0.4450  (P=0.2862, R=1.0000)


Epoch 77: 100%|██████████| 2/2 [00:00<00:00, 34.72it/s, avg_train_loss=1.74, batch_loss=1.76, lr=0.0001]


Val loss: 3.4446  Val F1: 0.4552  (P=0.2947, R=1.0000)


Epoch 78: 100%|██████████| 2/2 [00:00<00:00, 32.47it/s, avg_train_loss=1.72, batch_loss=1.78, lr=0.0001]


Val loss: 3.4040  Val F1: 0.4639  (P=0.3020, R=1.0000)


Epoch 79: 100%|██████████| 2/2 [00:00<00:00, 38.49it/s, avg_train_loss=1.69, batch_loss=1.74, lr=0.0001]


Val loss: 3.2878  Val F1: 0.4419  (P=0.2836, R=1.0000)


Epoch 80: 100%|██████████| 2/2 [00:00<00:00, 32.10it/s, avg_train_loss=1.64, batch_loss=1.63, lr=0.0001]


Val loss: 3.3091  Val F1: 0.4448  (P=0.2860, R=1.0000)


Epoch 81: 100%|██████████| 2/2 [00:00<00:00, 34.78it/s, avg_train_loss=1.62, batch_loss=1.65, lr=0.0001]


Val loss: 3.1725  Val F1: 0.4685  (P=0.3059, R=1.0000)


Epoch 82: 100%|██████████| 2/2 [00:00<00:00, 35.05it/s, avg_train_loss=1.58, batch_loss=1.57, lr=0.0001]


Val loss: 3.0908  Val F1: 0.4599  (P=0.2986, R=1.0000)


Epoch 83: 100%|██████████| 2/2 [00:00<00:00, 35.51it/s, avg_train_loss=1.55, batch_loss=1.53, lr=0.0001]


Val loss: 3.0833  Val F1: 0.4700  (P=0.3072, R=1.0000)


Epoch 84: 100%|██████████| 2/2 [00:00<00:00, 30.44it/s, avg_train_loss=1.51, batch_loss=1.5, lr=0.0001]


Val loss: 3.0741  Val F1: 0.5184  (P=0.3499, R=1.0000)


Epoch 85: 100%|██████████| 2/2 [00:00<00:00, 29.91it/s, avg_train_loss=1.48, batch_loss=1.45, lr=0.0001]


Val loss: 2.9611  Val F1: 0.4855  (P=0.3206, R=1.0000)


Epoch 86: 100%|██████████| 2/2 [00:00<00:00, 30.57it/s, avg_train_loss=1.45, batch_loss=1.45, lr=0.0001]


Val loss: 2.8430  Val F1: 0.4681  (P=0.3056, R=1.0000)


Epoch 87: 100%|██████████| 2/2 [00:00<00:00, 37.44it/s, avg_train_loss=1.44, batch_loss=1.48, lr=0.0001]


Val loss: 2.7677  Val F1: 0.5149  (P=0.3467, R=1.0000)


Epoch 88: 100%|██████████| 2/2 [00:00<00:00, 31.36it/s, avg_train_loss=1.38, batch_loss=1.36, lr=0.0001]


Val loss: 2.6870  Val F1: 0.4936  (P=0.3276, R=1.0000)


Epoch 89: 100%|██████████| 2/2 [00:00<00:00, 30.30it/s, avg_train_loss=1.35, batch_loss=1.36, lr=0.0001]


Val loss: 2.6834  Val F1: 0.4857  (P=0.3208, R=1.0000)


Epoch 90: 100%|██████████| 2/2 [00:00<00:00, 30.07it/s, avg_train_loss=1.34, batch_loss=1.32, lr=0.0001]


Val loss: 2.5897  Val F1: 0.5242  (P=0.3552, R=1.0000)


Epoch 91: 100%|██████████| 2/2 [00:00<00:00, 35.79it/s, avg_train_loss=1.32, batch_loss=1.31, lr=0.0001]


Val loss: 2.5702  Val F1: 0.5414  (P=0.3712, R=1.0000)


Epoch 92: 100%|██████████| 2/2 [00:00<00:00, 37.41it/s, avg_train_loss=1.3, batch_loss=1.33, lr=0.0001]


Val loss: 2.5371  Val F1: 0.4888  (P=0.3235, R=1.0000)


Epoch 93: 100%|██████████| 2/2 [00:00<00:00, 35.46it/s, avg_train_loss=1.27, batch_loss=1.25, lr=0.0001]


Val loss: 2.5045  Val F1: 0.5033  (P=0.3363, R=1.0000)


Epoch 94: 100%|██████████| 2/2 [00:00<00:00, 29.71it/s, avg_train_loss=1.25, batch_loss=1.28, lr=0.0001]


Val loss: 2.4244  Val F1: 0.5245  (P=0.3554, R=1.0000)


Epoch 95: 100%|██████████| 2/2 [00:00<00:00, 33.96it/s, avg_train_loss=1.21, batch_loss=1.18, lr=0.0001]


Val loss: 2.3944  Val F1: 0.5333  (P=0.3636, R=1.0000)


Epoch 96: 100%|██████████| 2/2 [00:00<00:00, 31.55it/s, avg_train_loss=1.21, batch_loss=1.25, lr=0.0001]


Val loss: 2.3696  Val F1: 0.5616  (P=0.3904, R=1.0000)


Epoch 97: 100%|██████████| 2/2 [00:00<00:00, 29.76it/s, avg_train_loss=1.18, batch_loss=1.18, lr=0.0001]


Val loss: 2.3538  Val F1: 0.5660  (P=0.3947, R=1.0000)


Epoch 98: 100%|██████████| 2/2 [00:00<00:00, 35.57it/s, avg_train_loss=1.16, batch_loss=1.14, lr=0.0001]


Val loss: 2.3953  Val F1: 0.5439  (P=0.3735, R=1.0000)


Epoch 99: 100%|██████████| 2/2 [00:00<00:00, 34.58it/s, avg_train_loss=1.14, batch_loss=1.14, lr=0.0001]


Val loss: 2.2501  Val F1: 0.5384  (P=0.3684, R=1.0000)


Epoch 100: 100%|██████████| 2/2 [00:00<00:00, 30.10it/s, avg_train_loss=1.12, batch_loss=1.12, lr=0.0001]


Val loss: 2.2056  Val F1: 0.5514  (P=0.3807, R=1.0000)


Epoch 101: 100%|██████████| 2/2 [00:00<00:00, 30.38it/s, avg_train_loss=1.1, batch_loss=1.08, lr=0.0001]


Val loss: 2.1688  Val F1: 0.5678  (P=0.3964, R=1.0000)


Epoch 102: 100%|██████████| 2/2 [00:00<00:00, 32.21it/s, avg_train_loss=1.08, batch_loss=1.1, lr=0.0001]


Val loss: 2.2213  Val F1: 0.5788  (P=0.4073, R=1.0000)


Epoch 103: 100%|██████████| 2/2 [00:00<00:00, 35.99it/s, avg_train_loss=1.07, batch_loss=1.12, lr=0.0001]


Val loss: 2.0836  Val F1: 0.5969  (P=0.4254, R=1.0000)


Epoch 104: 100%|██████████| 2/2 [00:00<00:00, 36.40it/s, avg_train_loss=1.04, batch_loss=1.06, lr=0.0001]


Val loss: 2.0671  Val F1: 0.5975  (P=0.4261, R=1.0000)


Epoch 105: 100%|██████████| 2/2 [00:00<00:00, 36.37it/s, avg_train_loss=1.01, batch_loss=1, lr=0.0001]


Val loss: 1.9888  Val F1: 0.5785  (P=0.4070, R=1.0000)


Epoch 106: 100%|██████████| 2/2 [00:00<00:00, 31.03it/s, avg_train_loss=1.01, batch_loss=1.06, lr=0.0001]


Val loss: 1.9733  Val F1: 0.6006  (P=0.4291, R=1.0000)


Epoch 107: 100%|██████████| 2/2 [00:00<00:00, 30.36it/s, avg_train_loss=0.973, batch_loss=0.971, lr=0.0001]


Val loss: 1.9070  Val F1: 0.5791  (P=0.4076, R=1.0000)


Epoch 108: 100%|██████████| 2/2 [00:00<00:00, 31.37it/s, avg_train_loss=0.964, batch_loss=0.924, lr=0.0001]


Val loss: 1.8802  Val F1: 0.5817  (P=0.4101, R=1.0000)


Epoch 109: 100%|██████████| 2/2 [00:00<00:00, 32.00it/s, avg_train_loss=0.938, batch_loss=0.973, lr=0.0001]


Val loss: 1.8539  Val F1: 0.6347  (P=0.4649, R=1.0000)


Epoch 110: 100%|██████████| 2/2 [00:00<00:00, 33.57it/s, avg_train_loss=0.963, batch_loss=1.09, lr=0.0001]


Val loss: 1.8058  Val F1: 0.6646  (P=0.4977, R=1.0000)


Epoch 111: 100%|██████████| 2/2 [00:00<00:00, 31.22it/s, avg_train_loss=0.99, batch_loss=1.11, lr=0.0001]


Val loss: 1.8594  Val F1: 0.5214  (P=0.3526, R=1.0000)


Epoch 112: 100%|██████████| 2/2 [00:00<00:00, 31.94it/s, avg_train_loss=0.943, batch_loss=0.967, lr=0.0001]


Val loss: 1.8167  Val F1: 0.5083  (P=0.3408, R=1.0000)


Epoch 113: 100%|██████████| 2/2 [00:00<00:00, 37.60it/s, avg_train_loss=0.998, batch_loss=1.14, lr=0.0001]


Val loss: 1.6204  Val F1: 0.6122  (P=0.4412, R=1.0000)


Epoch 114: 100%|██████████| 2/2 [00:00<00:00, 36.20it/s, avg_train_loss=0.877, batch_loss=0.681, lr=0.0001]


Val loss: 1.4739  Val F1: 0.6164  (P=0.4456, R=1.0000)


Epoch 115: 100%|██████████| 2/2 [00:00<00:00, 31.12it/s, avg_train_loss=0.742, batch_loss=0.665, lr=0.0001]


Val loss: 1.3851  Val F1: 0.6370  (P=0.4673, R=1.0000)


Epoch 116: 100%|██████████| 2/2 [00:00<00:00, 33.50it/s, avg_train_loss=0.688, batch_loss=0.721, lr=0.0001]


Val loss: 1.2079  Val F1: 0.6358  (P=0.4661, R=1.0000)


Epoch 117: 100%|██████████| 2/2 [00:00<00:00, 31.89it/s, avg_train_loss=0.646, batch_loss=0.548, lr=0.0001]


Val loss: 1.0697  Val F1: 0.6218  (P=0.4512, R=1.0000)


Epoch 118: 100%|██████████| 2/2 [00:00<00:00, 33.34it/s, avg_train_loss=0.517, batch_loss=0.497, lr=0.0001]


Val loss: 0.9373  Val F1: 0.6084  (P=0.4372, R=1.0000)


Epoch 119: 100%|██████████| 2/2 [00:00<00:00, 34.00it/s, avg_train_loss=0.505, batch_loss=0.591, lr=0.0001]


Val loss: 0.9296  Val F1: 0.6339  (P=0.4641, R=1.0000)


Epoch 120: 100%|██████████| 2/2 [00:00<00:00, 32.43it/s, avg_train_loss=0.4, batch_loss=0.391, lr=0.0001]


Val loss: 0.8075  Val F1: 0.6225  (P=0.4519, R=1.0000)


Epoch 121: 100%|██████████| 2/2 [00:00<00:00, 32.17it/s, avg_train_loss=0.348, batch_loss=0.385, lr=0.0001]


Val loss: 0.6107  Val F1: 0.6287  (P=0.4585, R=1.0000)


Epoch 122: 100%|██████████| 2/2 [00:00<00:00, 32.31it/s, avg_train_loss=0.292, batch_loss=0.314, lr=0.0001]


Val loss: 0.6241  Val F1: 0.6939  (P=0.5312, R=1.0000)


Epoch 123: 100%|██████████| 2/2 [00:00<00:00, 31.41it/s, avg_train_loss=0.249, batch_loss=0.248, lr=0.0001]


Val loss: 0.4199  Val F1: 0.6679  (P=0.5014, R=1.0000)


Epoch 124: 100%|██████████| 2/2 [00:00<00:00, 32.29it/s, avg_train_loss=0.24, batch_loss=0.262, lr=0.0001]


Val loss: 0.5233  Val F1: 0.6136  (P=0.4426, R=1.0000)


Epoch 125: 100%|██████████| 2/2 [00:00<00:00, 36.92it/s, avg_train_loss=0.226, batch_loss=0.252, lr=0.0001]


Val loss: 0.3929  Val F1: 0.6396  (P=0.4702, R=1.0000)


Epoch 126: 100%|██████████| 2/2 [00:00<00:00, 31.47it/s, avg_train_loss=0.187, batch_loss=0.157, lr=0.0001]


Val loss: 0.3520  Val F1: 0.6916  (P=0.5286, R=1.0000)


Epoch 127: 100%|██████████| 2/2 [00:00<00:00, 30.50it/s, avg_train_loss=0.157, batch_loss=0.143, lr=0.0001]


Val loss: 0.2749  Val F1: 0.7002  (P=0.5387, R=1.0000)


Epoch 128: 100%|██████████| 2/2 [00:00<00:00, 33.61it/s, avg_train_loss=0.137, batch_loss=0.128, lr=0.0001]


Val loss: 0.3249  Val F1: 0.7415  (P=0.5899, R=0.9981)


Epoch 129: 100%|██████████| 2/2 [00:00<00:00, 37.33it/s, avg_train_loss=0.145, batch_loss=0.118, lr=0.0001]


Val loss: 0.2308  Val F1: 0.7109  (P=0.5514, R=1.0000)


Epoch 130: 100%|██████████| 2/2 [00:00<00:00, 36.15it/s, avg_train_loss=0.132, batch_loss=0.146, lr=0.0001]


Val loss: 0.2182  Val F1: 0.7229  (P=0.5660, R=1.0000)


Epoch 131: 100%|██████████| 2/2 [00:00<00:00, 37.08it/s, avg_train_loss=0.118, batch_loss=0.132, lr=0.0001]


Val loss: 0.2115  Val F1: 0.7195  (P=0.5618, R=1.0000)


Epoch 132: 100%|██████████| 2/2 [00:00<00:00, 37.95it/s, avg_train_loss=0.122, batch_loss=0.137, lr=0.0001]


Val loss: 0.2080  Val F1: 0.7011  (P=0.5398, R=1.0000)


Epoch 133: 100%|██████████| 2/2 [00:00<00:00, 35.85it/s, avg_train_loss=0.136, batch_loss=0.169, lr=0.0001]


Val loss: 0.2353  Val F1: 0.6948  (P=0.5323, R=1.0000)


Epoch 134: 100%|██████████| 2/2 [00:00<00:00, 36.43it/s, avg_train_loss=0.132, batch_loss=0.102, lr=0.0001]


Val loss: 0.2037  Val F1: 0.7614  (P=0.6147, R=1.0000)


Epoch 135: 100%|██████████| 2/2 [00:00<00:00, 39.59it/s, avg_train_loss=0.135, batch_loss=0.163, lr=0.0001]


Val loss: 0.1933  Val F1: 0.7679  (P=0.6233, R=1.0000)


Epoch 136: 100%|██████████| 2/2 [00:00<00:00, 35.36it/s, avg_train_loss=0.102, batch_loss=0.106, lr=0.0001]


Val loss: 0.1868  Val F1: 0.7403  (P=0.5877, R=1.0000)


Epoch 137: 100%|██████████| 2/2 [00:00<00:00, 30.38it/s, avg_train_loss=0.0984, batch_loss=0.0862, lr=0.0001]


Val loss: 0.1777  Val F1: 0.7576  (P=0.6098, R=1.0000)


Epoch 138: 100%|██████████| 2/2 [00:00<00:00, 31.79it/s, avg_train_loss=0.0912, batch_loss=0.0811, lr=0.0001]


Val loss: 0.1768  Val F1: 0.7825  (P=0.6427, R=1.0000)


Epoch 139: 100%|██████████| 2/2 [00:00<00:00, 30.82it/s, avg_train_loss=0.0845, batch_loss=0.0706, lr=0.0001]


Val loss: 0.1815  Val F1: 0.7842  (P=0.6450, R=1.0000)


Epoch 140: 100%|██████████| 2/2 [00:00<00:00, 38.20it/s, avg_train_loss=0.0793, batch_loss=0.0773, lr=0.0001]


Val loss: 0.1769  Val F1: 0.7768  (P=0.6351, R=1.0000)


Epoch 141: 100%|██████████| 2/2 [00:00<00:00, 35.57it/s, avg_train_loss=0.0781, batch_loss=0.0722, lr=0.0001]


Val loss: 0.1697  Val F1: 0.7762  (P=0.6343, R=1.0000)


Epoch 142: 100%|██████████| 2/2 [00:00<00:00, 43.03it/s, avg_train_loss=0.0768, batch_loss=0.0729, lr=0.0001]


Val loss: 0.1531  Val F1: 0.7911  (P=0.6545, R=1.0000)


Epoch 143: 100%|██████████| 2/2 [00:00<00:00, 29.81it/s, avg_train_loss=0.0733, batch_loss=0.0578, lr=0.0001]


Val loss: 0.1437  Val F1: 0.7911  (P=0.6545, R=1.0000)


Epoch 144: 100%|██████████| 2/2 [00:00<00:00, 33.55it/s, avg_train_loss=0.0714, batch_loss=0.0669, lr=0.0001]


Val loss: 0.1336  Val F1: 0.7906  (P=0.6537, R=1.0000)


Epoch 145: 100%|██████████| 2/2 [00:00<00:00, 31.12it/s, avg_train_loss=0.0703, batch_loss=0.0731, lr=0.0001]


Val loss: 0.1301  Val F1: 0.7947  (P=0.6593, R=1.0000)


Epoch 146: 100%|██████████| 2/2 [00:00<00:00, 37.60it/s, avg_train_loss=0.0659, batch_loss=0.0653, lr=0.0001]


Val loss: 0.1267  Val F1: 0.8024  (P=0.6700, R=1.0000)


Epoch 147: 100%|██████████| 2/2 [00:00<00:00, 35.96it/s, avg_train_loss=0.0641, batch_loss=0.0606, lr=0.0001]


Val loss: 0.1274  Val F1: 0.8072  (P=0.6768, R=1.0000)


Epoch 148: 100%|██████████| 2/2 [00:00<00:00, 30.19it/s, avg_train_loss=0.0626, batch_loss=0.0648, lr=0.0001]


Val loss: 0.1204  Val F1: 0.8158  (P=0.6889, R=1.0000)


Epoch 149: 100%|██████████| 2/2 [00:00<00:00, 32.10it/s, avg_train_loss=0.0607, batch_loss=0.0649, lr=0.0001]


Val loss: 0.1223  Val F1: 0.8196  (P=0.6943, R=1.0000)


Epoch 150: 100%|██████████| 2/2 [00:00<00:00, 29.87it/s, avg_train_loss=0.0606, batch_loss=0.057, lr=0.0001]


Val loss: 0.1131  Val F1: 0.8240  (P=0.7007, R=1.0000)


Epoch 151: 100%|██████████| 2/2 [00:00<00:00, 35.29it/s, avg_train_loss=0.0569, batch_loss=0.0558, lr=0.0001]


Val loss: 0.1398  Val F1: 0.8240  (P=0.7007, R=1.0000)


Epoch 152: 100%|██████████| 2/2 [00:00<00:00, 35.61it/s, avg_train_loss=0.064, batch_loss=0.0634, lr=0.0001]


Val loss: 0.1121  Val F1: 0.8323  (P=0.7128, R=1.0000)


Epoch 153: 100%|██████████| 2/2 [00:00<00:00, 37.51it/s, avg_train_loss=0.0546, batch_loss=0.0494, lr=0.0001]


Val loss: 0.1126  Val F1: 0.8488  (P=0.7373, R=1.0000)


Epoch 154: 100%|██████████| 2/2 [00:00<00:00, 30.96it/s, avg_train_loss=0.0614, batch_loss=0.0557, lr=0.0001]


Val loss: 0.1061  Val F1: 0.8323  (P=0.7128, R=1.0000)


Epoch 155: 100%|██████████| 2/2 [00:00<00:00, 38.03it/s, avg_train_loss=0.0535, batch_loss=0.0589, lr=0.0001]


Val loss: 0.1058  Val F1: 0.8388  (P=0.7224, R=1.0000)


Epoch 156: 100%|██████████| 2/2 [00:00<00:00, 36.83it/s, avg_train_loss=0.0531, batch_loss=0.059, lr=0.0001]


Val loss: 0.1000  Val F1: 0.8474  (P=0.7353, R=1.0000)


Epoch 157: 100%|██████████| 2/2 [00:00<00:00, 36.91it/s, avg_train_loss=0.0499, batch_loss=0.0463, lr=0.0001]


Val loss: 0.0996  Val F1: 0.8401  (P=0.7243, R=1.0000)


Epoch 158: 100%|██████████| 2/2 [00:00<00:00, 36.89it/s, avg_train_loss=0.0492, batch_loss=0.0586, lr=0.0001]


Val loss: 0.0982  Val F1: 0.8508  (P=0.7403, R=1.0000)


Epoch 159: 100%|██████████| 2/2 [00:00<00:00, 37.94it/s, avg_train_loss=0.0496, batch_loss=0.0408, lr=0.0001]


Val loss: 0.0945  Val F1: 0.8576  (P=0.7507, R=1.0000)


Epoch 160: 100%|██████████| 2/2 [00:00<00:00, 31.40it/s, avg_train_loss=0.0476, batch_loss=0.0495, lr=0.0001]


Val loss: 0.1184  Val F1: 0.8631  (P=0.7592, R=1.0000)


Epoch 161: 100%|██████████| 2/2 [00:00<00:00, 31.44it/s, avg_train_loss=0.0471, batch_loss=0.0564, lr=0.0001]


Val loss: 0.0971  Val F1: 0.8562  (P=0.7486, R=1.0000)


Epoch 162: 100%|██████████| 2/2 [00:00<00:00, 29.87it/s, avg_train_loss=0.0462, batch_loss=0.0517, lr=0.0001]


Val loss: 0.0877  Val F1: 0.8549  (P=0.7465, R=1.0000)


Epoch 163: 100%|██████████| 2/2 [00:00<00:00, 31.88it/s, avg_train_loss=0.0454, batch_loss=0.0409, lr=0.0001]


Val loss: 0.0893  Val F1: 0.8680  (P=0.7668, R=1.0000)


Epoch 164: 100%|██████████| 2/2 [00:00<00:00, 32.95it/s, avg_train_loss=0.0498, batch_loss=0.0647, lr=0.0001]


Val loss: 0.0882  Val F1: 0.8631  (P=0.7592, R=1.0000)


Epoch 165: 100%|██████████| 2/2 [00:00<00:00, 36.48it/s, avg_train_loss=0.0453, batch_loss=0.0391, lr=0.0001]


Val loss: 0.0944  Val F1: 0.8349  (P=0.7166, R=1.0000)


Epoch 166: 100%|██████████| 2/2 [00:00<00:00, 38.54it/s, avg_train_loss=0.0487, batch_loss=0.0684, lr=0.0001]


Val loss: 0.0886  Val F1: 0.8576  (P=0.7507, R=1.0000)


Epoch 167: 100%|██████████| 2/2 [00:00<00:00, 31.43it/s, avg_train_loss=0.0478, batch_loss=0.0473, lr=0.0001]


Val loss: 0.1056  Val F1: 0.8652  (P=0.7624, R=1.0000)


Epoch 168: 100%|██████████| 2/2 [00:00<00:00, 40.51it/s, avg_train_loss=0.0651, batch_loss=0.0839, lr=0.0001]


Val loss: 0.1188  Val F1: 0.8140  (P=0.6863, R=1.0000)


Epoch 169: 100%|██████████| 2/2 [00:00<00:00, 35.25it/s, avg_train_loss=0.137, batch_loss=0.227, lr=0.0001]


Val loss: 0.1801  Val F1: 0.7424  (P=0.5903, R=1.0000)


Epoch 170: 100%|██████████| 2/2 [00:00<00:00, 37.82it/s, avg_train_loss=0.103, batch_loss=0.113, lr=0.0001]


Val loss: 2.8412  Val F1: 0.6329  (P=0.4851, R=0.9104)


Epoch 171: 100%|██████████| 2/2 [00:00<00:00, 32.71it/s, avg_train_loss=3.69, batch_loss=7.24, lr=0.0001]


Val loss: 0.5744  Val F1: 0.5826  (P=0.4110, R=1.0000)


Epoch 172: 100%|██████████| 2/2 [00:00<00:00, 32.52it/s, avg_train_loss=0.711, batch_loss=1.19, lr=0.0001]


Val loss: 2.1713  Val F1: 0.4448  (P=0.2886, R=0.9701)


Epoch 173: 100%|██████████| 2/2 [00:00<00:00, 29.97it/s, avg_train_loss=1.37, batch_loss=1.19, lr=0.0001]


Val loss: 2.3280  Val F1: 0.2710  (P=0.1568, R=1.0000)


Epoch 174: 100%|██████████| 2/2 [00:00<00:00, 34.33it/s, avg_train_loss=1.09, batch_loss=1.02, lr=0.0001]


Val loss: 2.3776  Val F1: 0.4297  (P=0.2736, R=1.0000)


Epoch 175: 100%|██████████| 2/2 [00:00<00:00, 38.80it/s, avg_train_loss=1.21, batch_loss=1.23, lr=0.0001]


Val loss: 2.3924  Val F1: 0.2086  (P=0.1165, R=1.0000)


Epoch 176: 100%|██████████| 2/2 [00:00<00:00, 40.10it/s, avg_train_loss=1.05, batch_loss=1.07, lr=0.0001]


Val loss: 2.3673  Val F1: 0.4852  (P=0.3228, R=0.9757)


Epoch 177: 100%|██████████| 2/2 [00:00<00:00, 38.09it/s, avg_train_loss=1, batch_loss=0.706, lr=0.0001]


Val loss: 1.3290  Val F1: 0.3688  (P=0.2261, R=1.0000)


Epoch 178: 100%|██████████| 2/2 [00:00<00:00, 39.41it/s, avg_train_loss=0.721, batch_loss=0.83, lr=0.0001]


Val loss: 1.0836  Val F1: 0.4230  (P=0.2683, R=1.0000)


Epoch 179: 100%|██████████| 2/2 [00:00<00:00, 39.73it/s, avg_train_loss=0.517, batch_loss=0.479, lr=0.0001]


Val loss: 1.8403  Val F1: 0.5682  (P=0.3989, R=0.9869)


Epoch 180: 100%|██████████| 2/2 [00:00<00:00, 39.21it/s, avg_train_loss=0.791, batch_loss=0.655, lr=0.0001]


Val loss: 1.8174  Val F1: 0.2879  (P=0.1682, R=1.0000)


Epoch 181: 100%|██████████| 2/2 [00:00<00:00, 36.93it/s, avg_train_loss=0.912, batch_loss=0.993, lr=0.0001]


Val loss: 1.8221  Val F1: 0.2923  (P=0.1712, R=1.0000)


Epoch 182: 100%|██████████| 2/2 [00:00<00:00, 40.23it/s, avg_train_loss=1.12, batch_loss=0.737, lr=0.0001]


Val loss: 1.1147  Val F1: 0.4192  (P=0.2652, R=1.0000)


Epoch 183: 100%|██████████| 2/2 [00:00<00:00, 32.87it/s, avg_train_loss=0.458, batch_loss=0.42, lr=0.0001]


Val loss: 2.4806  Val F1: 0.6261  (P=0.4823, R=0.8918)


Epoch 184: 100%|██████████| 2/2 [00:00<00:00, 30.03it/s, avg_train_loss=0.674, batch_loss=0.328, lr=0.0001]


Val loss: 0.8376  Val F1: 0.4621  (P=0.3004, R=1.0000)


Epoch 185: 100%|██████████| 2/2 [00:00<00:00, 32.08it/s, avg_train_loss=0.442, batch_loss=0.413, lr=0.0001]


Val loss: 0.9676  Val F1: 0.4307  (P=0.2744, R=1.0000)


Epoch 186: 100%|██████████| 2/2 [00:00<00:00, 30.05it/s, avg_train_loss=0.484, batch_loss=0.564, lr=0.0001]


Val loss: 0.9284  Val F1: 0.4508  (P=0.2910, R=1.0000)


Epoch 187: 100%|██████████| 2/2 [00:00<00:00, 36.44it/s, avg_train_loss=0.41, batch_loss=0.359, lr=0.0001]


Val loss: 0.8435  Val F1: 0.5078  (P=0.3405, R=0.9981)


Epoch 188: 100%|██████████| 2/2 [00:00<00:00, 38.01it/s, avg_train_loss=0.389, batch_loss=0.362, lr=0.0001]


Val loss: 1.0240  Val F1: 0.5743  (P=0.4032, R=0.9981)


Epoch 189: 100%|██████████| 2/2 [00:00<00:00, 41.31it/s, avg_train_loss=0.381, batch_loss=0.223, lr=0.0001]


Val loss: 0.8159  Val F1: 0.5390  (P=0.3689, R=1.0000)


Epoch 190: 100%|██████████| 2/2 [00:00<00:00, 30.43it/s, avg_train_loss=0.286, batch_loss=0.207, lr=0.0001]


Val loss: 0.5281  Val F1: 0.5453  (P=0.3748, R=1.0000)


Epoch 191: 100%|██████████| 2/2 [00:00<00:00, 30.42it/s, avg_train_loss=0.272, batch_loss=0.196, lr=0.0001]


Val loss: 0.4864  Val F1: 0.5779  (P=0.4064, R=1.0000)


Epoch 192: 100%|██████████| 2/2 [00:00<00:00, 35.48it/s, avg_train_loss=0.309, batch_loss=0.265, lr=0.0001]


Val loss: 0.4654  Val F1: 0.6077  (P=0.4365, R=1.0000)


Epoch 193: 100%|██████████| 2/2 [00:00<00:00, 36.66it/s, avg_train_loss=0.248, batch_loss=0.215, lr=0.0001]


Val loss: 0.4056  Val F1: 0.6091  (P=0.4379, R=1.0000)


Epoch 194: 100%|██████████| 2/2 [00:00<00:00, 38.73it/s, avg_train_loss=0.242, batch_loss=0.183, lr=0.0001]


Val loss: 0.3823  Val F1: 0.6236  (P=0.4531, R=1.0000)


Epoch 195: 100%|██████████| 2/2 [00:00<00:00, 38.47it/s, avg_train_loss=0.192, batch_loss=0.161, lr=0.0001]


Val loss: 0.3799  Val F1: 0.6361  (P=0.4668, R=0.9981)


Epoch 196: 100%|██████████| 2/2 [00:00<00:00, 35.56it/s, avg_train_loss=0.187, batch_loss=0.14, lr=0.0001]


Val loss: 0.4195  Val F1: 0.6407  (P=0.4718, R=0.9981)


Epoch 197: 100%|██████████| 2/2 [00:00<00:00, 33.65it/s, avg_train_loss=0.199, batch_loss=0.228, lr=0.0001]


Val loss: 0.3760  Val F1: 0.6477  (P=0.4794, R=0.9981)


Epoch 198: 100%|██████████| 2/2 [00:00<00:00, 31.45it/s, avg_train_loss=0.182, batch_loss=0.18, lr=0.0001]


Val loss: 0.3935  Val F1: 0.6377  (P=0.4681, R=1.0000)


Epoch 199: 100%|██████████| 2/2 [00:00<00:00, 39.40it/s, avg_train_loss=0.174, batch_loss=0.182, lr=0.0001]


Val loss: 0.3385  Val F1: 0.6446  (P=0.4756, R=1.0000)


Epoch 200: 100%|██████████| 2/2 [00:00<00:00, 39.06it/s, avg_train_loss=0.163, batch_loss=0.161, lr=0.0001]


Val loss: 0.3146  Val F1: 0.6621  (P=0.4949, R=1.0000)


In [11]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class DoubleConv(nn.Module):
    """
    A sequence of two 3x3 convolutions each followed by ReLU activation.
    """
    def __init__(self, in_channels, out_channels):
        super(DoubleConv, self).__init__()
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=True),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            #nn.GELU(),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=True),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
            #nn.GELU(),
        )

    def forward(self, x):
        return self.double_conv(x)

class UNet(nn.Module):
    """
    A basic U-Net architecture without attention, SE blocks, or deep supervision.
    """
    def __init__(self, in_channels=1, out_channels=1, features=None):
        super(UNet, self).__init__()
        if features is None:
            features = [64, 128, 256, 512]

        # Encoder path
        self.downs = nn.ModuleList()
        for feature in features:
            self.downs.append(DoubleConv(in_channels, feature))
            in_channels = feature

        # Bottleneck
        self.bottleneck = DoubleConv(features[-1], features[-1] * 2)

        # Decoder path
        self.ups = nn.ModuleList()
        rev_features = features[::-1]
        for feature in rev_features:
            self.ups.append(nn.ConvTranspose2d(feature * 2, feature, kernel_size=3, stride=2))
            self.ups.append(DoubleConv(feature * 2, feature))

        # Final 1x1 conv
        self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)

    def forward(self, x):
        skip_connections = []

        # Downsampling
        for down in self.downs:
            x = down(x)
            skip_connections.append(x)
            x = F.max_pool2d(x, kernel_size=3, stride=2)

        # Bottleneck
        x = self.bottleneck(x)

        # Reverse skip connections for decoding
        skip_connections = skip_connections[::-1]

        # Upsampling
        for idx in range(0, len(self.ups), 2):
            x = self.ups[idx](x)  # ConvTranspose2d
            skip = skip_connections[idx // 2]
            # In case the inexact sizes due to pooling
            if x.shape != skip.shape:
                x = F.interpolate(x, size=skip.shape[2:], mode='bilinear', align_corners=False)
            x = torch.cat((skip, x), dim=1)
            x = self.ups[idx + 1](x)  # DoubleConv

        return self.final_conv(x)


In [12]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split, SubsetRandomSampler
import h5py
from tqdm import tqdm
import random

class DiceLoss(nn.Module):
    def __init__(self, eps=1e-6):
        super().__init__()
        self.eps = eps

    def forward(self, logits, targets):
        probs = torch.sigmoid(logits)
        num   = 2 * (probs * targets).sum(dim=(1,2,3))
        den   = probs.sum(dim=(1,2,3)) + targets.sum(dim=(1,2,3)) + self.eps
        dice  = num / den
        return 1 - dice.mean()

class BCEDiceLoss(nn.Module):
    def __init__(self, bce_weight=0.5):
        super().__init__()
        self.bce = nn.BCEWithLogitsLoss()
        self.dice = DiceLoss()
        self.bw = bce_weight

    def forward(self, logits, targets):
        return self.bw * self.bce(logits, targets) + (1 - self.bw) * self.dice(logits, targets)

class FocalLoss(nn.Module):
    def __init__(self, alpha=0.25, gamma=2.0, eps=1e-6):
        super().__init__()
        self.alpha, self.gamma, self.eps = alpha, gamma, eps

    def forward(self, logits, targets):
        probs = torch.sigmoid(logits).clamp(self.eps, 1 - self.eps)
        pt = torch.where(targets==1, probs, 1 - probs)
        w  = torch.where(targets==1, self.alpha, 1 - self.alpha)
        loss = - w * (1 - pt)**self.gamma * pt.log()
        return loss.mean()

class TverskyLoss(nn.Module):
    def __init__(self, alpha=0.7, beta=0.3, eps=1e-6):
        super().__init__()
        self.alpha, self.beta, self.eps = alpha, beta, eps

    def forward(self, logits, targets):
        probs = torch.sigmoid(logits)
        TP = (probs * targets).sum(dim=(1,2,3))
        FP = (probs * (1-targets)).sum(dim=(1,2,3))
        FN = ((1-probs) * targets).sum(dim=(1,2,3))
        tversky = (TP + self.eps) / (TP + self.alpha*FN + self.beta*FP + self.eps)
        return 1 - tversky.mean()

class FocalTverskyLoss(nn.Module):
    def __init__(self, alpha=0.7, beta=0.3, gamma=1.5, eps=1e-6):
        super().__init__()
        self.tversky = TverskyLoss(alpha, beta, eps)
        self.gamma = gamma

    def forward(self, logits, targets):
        t = 1 - self.tversky(logits, targets)
        return t ** self.gamma

class TverskyF2Loss(nn.Module):
    def forward(self, preds, targets):
        TP = (preds*targets).sum()
        FN = ((1-preds)*targets).sum()
        FP = (preds*(1-targets)).sum()
        return (FN + 1e-6) / (2*TP + FN + 1e-6)

class FocalTverskyLoss(nn.Module):
    def __init__(self, alpha=0.9, beta=0.1, gamma=2.0, eps=1e-6):
        super().__init__()
        self.alpha, self.beta, self.gamma, self.eps = alpha, beta, gamma, eps
    def forward(self, preds, targets):
        preds = preds.view(-1)
        targets = targets.view(-1)
        TP = (preds * targets).sum()
        FP = (preds * (1 - targets)).sum()
        FN = ((1 - preds) * targets).sum()
        tversky = (TP + self.eps) / (TP + self.alpha*FN + self.beta*FP + self.eps)
        return torch.pow((1 - tversky), self.gamma)

class ComboLoss(nn.Module):
    def __init__(self, pos_weight=500.0, bce_weight=0.2, dice_weight=0.4, f2_weight=0.4):
        super().__init__()
        self.bce  = nn.BCEWithLogitsLoss(pos_weight=torch.tensor(pos_weight))
        self.ft   = FocalTverskyLoss(alpha=0.9, beta=0.3, gamma=2.0)
        self.tf2 = TverskyF2Loss()
        self.bw, self.dw, self.fw= bce_weight, dice_weight, f2_weight
    def forward(self, logits, targets):
        bce_loss  = self.bce(logits, targets)
        dice_loss = self.ft(torch.sigmoid(logits), targets)
        f2_loss   = self.tf2(torch.sigmoid(logits), targets)
        return self.bw*bce_loss + self.dw*dice_loss + self.fw*f2_loss

class H5Dataset(Dataset):
    def __init__(self,path): self.path=path; self.f=None
    def __len__(self): return h5py.File(self.path,'r')['x'].shape[0]
    def __getitem__(self,idx):
        if self.f is None:
            self.f = h5py.File(self.path,'r', swmr=True)
        x = torch.from_numpy(self.f['x'][idx]).float()
        y = torch.from_numpy(self.f['y'][idx]).float()
        return x, y

In [None]:
from torch.amp import GradScaler, autocast
if __name__ == '__main__':
    bs, max_batches, epochs = 5, 500, 200
    train_path = 'train.h5'
    val_path   = 'val.h5'

    # ─── datasets ───
    train_ds = H5Dataset(train_path)
    val_ds   = H5Dataset(val_path)

    # ─── train loader ───
    # pick exactly bs*max_batches random samples from train set
    # total small‐loader size and split
    num_samples = 50
    num_pos     = num_samples // 2
    num_neg     = num_samples - num_pos

    # collect all pos/neg indices in order
    pos_idx, neg_idx = [], []
    for i in range(len(train_ds)):
        _, mask = train_ds[i]
        if mask.sum() > 0:
            pos_idx.append(i)
        else:
            neg_idx.append(i)

    # take the first N from each
    selected_pos = pos_idx[:num_pos]
    selected_neg = neg_idx[:num_neg]

    # combine (optionally keep positives first, or interleave)
    balanced_idx = selected_pos + selected_neg

    small_loader = DataLoader(
        train_ds,
        batch_size=5,
        sampler=SubsetRandomSampler(balanced_idx),
        num_workers=8,
        pin_memory=True,
        persistent_workers=True,
        prefetch_factor=4,
    )

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = UNet(in_channels=1, out_channels=1).to(device)
    crit  = ComboLoss(pos_weight=100, bce_weight=0.5, dice_weight=0.5, f2_weight=0.0)
    #crit = DiceLoss()
    #crit = BCEDiceLoss(bce_weight=100)
    #crit = FocalLoss(alpha=0.25, gamma=2.0)
    #crit = FocalTverskyLoss(alpha=0.9, beta=0.1, gamma=1)

    opt   = torch.optim.Adam(model.parameters(), lr=1e-2)
    #sched = torch.optim.lr_scheduler.ConstantLR(opt, factor=1.0, total_iters=epochs)
    sched = torch.optim.lr_scheduler.OneCycleLR(opt, max_lr=5e-4, steps_per_epoch=len(small_loader), epochs=epochs)

    ds_weights = [1, 1, 1]
    scaler = GradScaler('cuda')

    for epoch in range(1, epochs+1):
        # ——— train ———
        model.train()
        total_loss = 0.0
        for batch_num, (imgs, msk) in enumerate(pbar:=tqdm(small_loader, desc=f"Epoch {epoch}")):
            #imgs[msk==1] = imgs[msk==1]
            imgs = imgs/50
            imgs = imgs.to(device)
            msk  = msk.to(device)
            opt.zero_grad()
            with autocast('cuda'):
                seg_logits = model(imgs)
                loss = crit(seg_logits, msk)
            scaler.scale(loss).backward()
            scaler.unscale_(opt)
            #torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=4.0)
            scaler.step(opt)
            scaler.update()
            #loss.backward()
            #opt.step()

            total_loss += loss.item()
            pbar.set_postfix(avg_train_loss=total_loss/(batch_num+1), batch_loss=loss.item(), lr=sched.get_last_lr()[0])
            sched.step()

        # ——— validate ———
        model.eval()
        with torch.no_grad():
            tp=fp=fn=0
            loss = 0
            for imgs, msk in small_loader:
                #imgs[msk==1] = imgs[msk==1]
                imgs = imgs/50
                imgs = imgs.to(device)
                msk  = msk.to(device)

                seg_logits = model(imgs)
                loss += crit(seg_logits, msk)
                preds = (torch.sigmoid(seg_logits)>0.5).float().view(-1)
                t     = msk.view(-1)

                tp += (preds * t).sum().item()
                fp += (preds * (1-t)).sum().item()
                fn += ((1-preds)*t).sum().item()

            prec = tp/(tp+fp+1e-8)
            rec  = tp/(tp+fn+1e-8)
            f1   = 2*prec*rec/(prec+rec+1e-8)
            print(f"Val loss: {loss:.4f}  Val F1: {f1:.4f}  (P={prec:.4f}, R={rec:.4f})")


Epoch 1: 100%|██████████| 10/10 [00:00<00:00, 31.33it/s, avg_train_loss=0.925, batch_loss=0.891, lr=2.03e-5]


Val loss: 9.1768  Val F1: 0.0039  (P=0.0020, R=0.9819)


Epoch 2: 100%|██████████| 10/10 [00:00<00:00, 49.69it/s, avg_train_loss=0.91, batch_loss=0.894, lr=2.12e-5]


Val loss: 9.1233  Val F1: 0.0051  (P=0.0025, R=0.7938)


Epoch 3: 100%|██████████| 10/10 [00:00<00:00, 51.98it/s, avg_train_loss=0.905, batch_loss=0.973, lr=2.28e-5]


Val loss: 9.0104  Val F1: 0.0067  (P=0.0034, R=0.7402)


Epoch 4: 100%|██████████| 10/10 [00:00<00:00, 51.84it/s, avg_train_loss=0.896, batch_loss=0.89, lr=2.5e-5] 


Val loss: 8.9084  Val F1: 0.0081  (P=0.0041, R=0.7701)


Epoch 5: 100%|██████████| 10/10 [00:00<00:00, 54.30it/s, avg_train_loss=0.884, batch_loss=0.866, lr=2.79e-5]


Val loss: 8.7703  Val F1: 0.0100  (P=0.0050, R=0.8037)


Epoch 6: 100%|██████████| 10/10 [00:00<00:00, 53.87it/s, avg_train_loss=0.875, batch_loss=0.85, lr=3.14e-5]


Val loss: 8.6325  Val F1: 0.0129  (P=0.0065, R=0.8579)


Epoch 7: 100%|██████████| 10/10 [00:00<00:00, 52.32it/s, avg_train_loss=0.862, batch_loss=0.887, lr=3.55e-5]


Val loss: 8.5514  Val F1: 0.0139  (P=0.0070, R=0.9234)


Epoch 8: 100%|██████████| 10/10 [00:00<00:00, 50.26it/s, avg_train_loss=0.854, batch_loss=0.852, lr=4.03e-5]


Val loss: 8.5187  Val F1: 0.0128  (P=0.0064, R=0.9595)


Epoch 9: 100%|██████████| 10/10 [00:00<00:00, 48.62it/s, avg_train_loss=0.837, batch_loss=0.848, lr=4.57e-5]


Val loss: 8.6146  Val F1: 0.0133  (P=0.0067, R=0.9626)


Epoch 10: 100%|██████████| 10/10 [00:00<00:00, 50.23it/s, avg_train_loss=0.822, batch_loss=0.799, lr=5.16e-5]


Val loss: 10.7647  Val F1: 0.0085  (P=0.0043, R=0.9838)


Epoch 11: 100%|██████████| 10/10 [00:00<00:00, 55.31it/s, avg_train_loss=0.797, batch_loss=0.77, lr=5.82e-5]


Val loss: 7.8315  Val F1: 0.0367  (P=0.0187, R=0.9452)


Epoch 12: 100%|██████████| 10/10 [00:00<00:00, 54.95it/s, avg_train_loss=0.771, batch_loss=0.761, lr=6.52e-5]


Val loss: 8.3411  Val F1: 0.0132  (P=0.0067, R=0.9857)


Epoch 13: 100%|██████████| 10/10 [00:00<00:00, 49.35it/s, avg_train_loss=0.747, batch_loss=0.746, lr=7.29e-5]


Val loss: 7.5249  Val F1: 0.0516  (P=0.0265, R=0.9819)


Epoch 14: 100%|██████████| 10/10 [00:00<00:00, 47.86it/s, avg_train_loss=0.744, batch_loss=0.728, lr=8.1e-5]


Val loss: 7.7312  Val F1: 0.0301  (P=0.0153, R=0.9826)


Epoch 15: 100%|██████████| 10/10 [00:00<00:00, 53.05it/s, avg_train_loss=0.723, batch_loss=0.698, lr=8.96e-5]


Val loss: 7.3549  Val F1: 0.0449  (P=0.0230, R=0.9838)


Epoch 16: 100%|██████████| 10/10 [00:00<00:00, 53.49it/s, avg_train_loss=0.726, batch_loss=0.733, lr=9.87e-5]


Val loss: 7.7018  Val F1: 0.0291  (P=0.0148, R=0.9009)


Epoch 17: 100%|██████████| 10/10 [00:00<00:00, 50.57it/s, avg_train_loss=0.715, batch_loss=0.719, lr=0.000108]


Val loss: 7.4047  Val F1: 0.0348  (P=0.0177, R=0.9819)


Epoch 18: 100%|██████████| 10/10 [00:00<00:00, 50.66it/s, avg_train_loss=0.688, batch_loss=0.706, lr=0.000118]


Val loss: 6.9004  Val F1: 0.1088  (P=0.0576, R=0.9782)


Epoch 19: 100%|██████████| 10/10 [00:00<00:00, 48.74it/s, avg_train_loss=0.683, batch_loss=0.683, lr=0.000129]


Val loss: 6.8330  Val F1: 0.1134  (P=0.0601, R=0.9913)


Epoch 20: 100%|██████████| 10/10 [00:00<00:00, 49.18it/s, avg_train_loss=0.683, batch_loss=0.733, lr=0.000139]


Val loss: 6.9412  Val F1: 0.0645  (P=0.0334, R=0.9595)


Epoch 21: 100%|██████████| 10/10 [00:00<00:00, 48.75it/s, avg_train_loss=0.676, batch_loss=0.675, lr=0.00015]


Val loss: 7.3615  Val F1: 0.0448  (P=0.0229, R=0.9819)


Epoch 22: 100%|██████████| 10/10 [00:00<00:00, 47.59it/s, avg_train_loss=0.667, batch_loss=0.662, lr=0.000162]


Val loss: 6.6816  Val F1: 0.1593  (P=0.0869, R=0.9533)


Epoch 23: 100%|██████████| 10/10 [00:00<00:00, 49.54it/s, avg_train_loss=0.659, batch_loss=0.721, lr=0.000173]


Val loss: 6.5399  Val F1: 0.2071  (P=0.1162, R=0.9495)


Epoch 24: 100%|██████████| 10/10 [00:00<00:00, 52.58it/s, avg_train_loss=0.663, batch_loss=0.684, lr=0.000185]


Val loss: 7.2791  Val F1: 0.0384  (P=0.0197, R=0.8231)


Epoch 25: 100%|██████████| 10/10 [00:00<00:00, 54.30it/s, avg_train_loss=0.687, batch_loss=0.727, lr=0.000197]


Val loss: 6.7841  Val F1: 0.0769  (P=0.0401, R=0.8997)


Epoch 26: 100%|██████████| 10/10 [00:00<00:00, 52.34it/s, avg_train_loss=0.644, batch_loss=0.622, lr=0.000209]


Val loss: 7.6985  Val F1: 0.0213  (P=0.0108, R=0.8168)


Epoch 27: 100%|██████████| 10/10 [00:00<00:00, 49.57it/s, avg_train_loss=0.643, batch_loss=0.672, lr=0.000222]


Val loss: 19.8921  Val F1: 0.0053  (P=0.0027, R=0.9919)


Epoch 28: 100%|██████████| 10/10 [00:00<00:00, 48.01it/s, avg_train_loss=0.635, batch_loss=0.64, lr=0.000234]


Val loss: 6.4694  Val F1: 0.1206  (P=0.0646, R=0.9134)


Epoch 29: 100%|██████████| 10/10 [00:00<00:00, 47.44it/s, avg_train_loss=0.638, batch_loss=0.639, lr=0.000247]


Val loss: 8.4009  Val F1: 0.0167  (P=0.0084, R=0.9745)


Epoch 30: 100%|██████████| 10/10 [00:00<00:00, 50.01it/s, avg_train_loss=0.629, batch_loss=0.632, lr=0.000259]


Val loss: 6.0269  Val F1: 0.1956  (P=0.1086, R=0.9888)


Epoch 31: 100%|██████████| 10/10 [00:00<00:00, 46.84it/s, avg_train_loss=0.601, batch_loss=0.56, lr=0.000272]


Val loss: 5.9750  Val F1: 0.4483  (P=0.2900, R=0.9875)


Epoch 32: 100%|██████████| 10/10 [00:00<00:00, 51.45it/s, avg_train_loss=0.59, batch_loss=0.588, lr=0.000285]


Val loss: 5.8761  Val F1: 0.3378  (P=0.2037, R=0.9900)


Epoch 33: 100%|██████████| 10/10 [00:00<00:00, 46.84it/s, avg_train_loss=0.596, batch_loss=0.581, lr=0.000297]


Val loss: 6.4532  Val F1: 0.0802  (P=0.0418, R=0.9626)


Epoch 34: 100%|██████████| 10/10 [00:00<00:00, 52.15it/s, avg_train_loss=0.603, batch_loss=0.638, lr=0.000309]


Val loss: 6.4373  Val F1: 0.1769  (P=0.1025, R=0.6461)


Epoch 35: 100%|██████████| 10/10 [00:00<00:00, 47.62it/s, avg_train_loss=0.592, batch_loss=0.597, lr=0.000322]


Val loss: 8.4272  Val F1: 0.0146  (P=0.0074, R=0.9452)


Epoch 36: 100%|██████████| 10/10 [00:00<00:00, 51.66it/s, avg_train_loss=0.58, batch_loss=0.598, lr=0.000334]


Val loss: 5.8164  Val F1: 0.4561  (P=0.3121, R=0.8474)


Epoch 37: 100%|██████████| 10/10 [00:00<00:00, 49.15it/s, avg_train_loss=0.583, batch_loss=0.592, lr=0.000346]


Val loss: 10.0656  Val F1: 0.0078  (P=0.0039, R=0.9103)


Epoch 38: 100%|██████████| 10/10 [00:00<00:00, 51.31it/s, avg_train_loss=0.569, batch_loss=0.573, lr=0.000357]


Val loss: 11.6360  Val F1: 0.0116  (P=0.0059, R=0.9938)


Epoch 39: 100%|██████████| 10/10 [00:00<00:00, 51.69it/s, avg_train_loss=0.592, batch_loss=0.55, lr=0.000369]


Val loss: 5.9324  Val F1: 0.1050  (P=0.0558, R=0.8847)


Epoch 40: 100%|██████████| 10/10 [00:00<00:00, 53.64it/s, avg_train_loss=0.58, batch_loss=0.585, lr=0.00038] 


Val loss: 5.7674  Val F1: 0.0885  (P=0.0463, R=0.9900)


Epoch 41: 100%|██████████| 10/10 [00:00<00:00, 46.12it/s, avg_train_loss=0.573, batch_loss=0.59, lr=0.00039] 


Val loss: 8.5872  Val F1: 0.0146  (P=0.0074, R=0.9539)


Epoch 42: 100%|██████████| 10/10 [00:00<00:00, 46.51it/s, avg_train_loss=0.552, batch_loss=0.546, lr=0.000401]


Val loss: 5.3670  Val F1: 0.4735  (P=0.3176, R=0.9296)


Epoch 43: 100%|██████████| 10/10 [00:00<00:00, 49.84it/s, avg_train_loss=0.556, batch_loss=0.532, lr=0.000411]


Val loss: 7.4565  Val F1: 0.0231  (P=0.0117, R=0.9869)


Epoch 44: 100%|██████████| 10/10 [00:00<00:00, 52.16it/s, avg_train_loss=0.556, batch_loss=0.511, lr=0.00042]


Val loss: 5.4109  Val F1: 0.3423  (P=0.2106, R=0.9128)


Epoch 45: 100%|██████████| 10/10 [00:00<00:00, 48.13it/s, avg_train_loss=0.512, batch_loss=0.481, lr=0.000429]


Val loss: 5.3122  Val F1: 0.1909  (P=0.1063, R=0.9364)


Epoch 46:  50%|█████     | 5/10 [00:00<00:00, 46.28it/s, avg_train_loss=0.51, batch_loss=0.553, lr=0.000437] 

In [89]:
if __name__ == '__main__':
    bs, max_batches, epochs = 32, 500, 100
    train_path = 'train.h5'
    val_path   = 'val.h5'

    # ─── datasets ───
    train_ds = H5Dataset(train_path)
    val_ds   = H5Dataset(val_path)

    # ─── train loader ───
    # pick exactly bs*max_batches random samples from train set
    tr_idx = torch.randperm(len(train_ds))[: bs * max_batches]
    tr_loader = DataLoader(
        train_ds,
        batch_size=bs,
        sampler=SubsetRandomSampler(tr_idx),
        num_workers=8,
        pin_memory=True,
        persistent_workers=True,
        prefetch_factor=4,
    )

    # ─── val loader ───
    # just iterate through val.h5 in order (or set shuffle=True if you like)
    val_loader = DataLoader(
        val_ds,
        batch_size=bs,
        shuffle=False,
        num_workers=8,
        pin_memory=True,
        persistent_workers=True,
        prefetch_factor=4,
    )
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = UNet(in_channels=1, out_channels=1).to(device)
    crit  = ComboLoss(pos_weight=100, bce_weight=0.5, dice_weight=0.5, f2_weight=0.0)
    #crit = DiceLoss()
    #crit = BCEDiceLoss(bce_weight=100)
    #crit = FocalLoss(alpha=0.25, gamma=2.0)
    #crit = FocalTverskyLoss(alpha=0.9, beta=0.1, gamma=1)

    opt   = torch.optim.Adam(model.parameters(), lr=1e-3)
    #sched = torch.optim.lr_scheduler.ConstantLR(opt, factor=1.0, total_iters=epochs)
    sched = torch.optim.lr_scheduler.OneCycleLR(
    opt, max_lr=1e-3,
    steps_per_epoch=len(tr_loader), epochs=epochs
)
    #sched = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(opt, T_0=20, T_mult=3)

    ds_weights = [1, 1, 1]
    scaler = GradScaler('cuda')

    for epoch in range(1, epochs+1):
        # ——— train ———
        model.train()
        total_loss = 0.0
        for batch_num, (imgs, msk) in enumerate(pbar:=tqdm(tr_loader, desc=f"Epoch {epoch}")):
            #imgs[msk==1] = imgs[msk==1] * 3
            imgs = imgs/50
            imgs = imgs.to(device)
            msk  = msk.to(device)
            opt.zero_grad()
            with autocast('cuda'):
                seg_logits = model(imgs)
                loss = crit(seg_logits, msk)
            scaler.scale(loss).backward()
            scaler.unscale_(opt)
            #torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=4.0)
            scaler.step(opt)
            scaler.update()
            #loss.backward()
            #opt.step()

            total_loss += loss.item()
            pbar.set_postfix(avg_train_loss=total_loss/(batch_num+1), batch_loss=loss.item(), lr=sched.get_last_lr()[0])
            sched.step()

        # ——— validate ———
        model.eval()
        with torch.no_grad():
            tp=fp=fn=0
            loss = 0
            for imgs, msk in val_loader:
                #imgs[msk==1] = imgs[msk==1] * 3
                imgs = imgs/50
                imgs = imgs.to(device)
                msk  = msk.to(device)

                seg_logits = model(imgs)
                loss += crit(seg_logits, msk)
                preds = (torch.sigmoid(seg_logits)>0.5).float().view(-1)
                t     = msk.view(-1)

                tp += (preds * t).sum().item()
                fp += (preds * (1-t)).sum().item()
                fn += ((1-preds)*t).sum().item()

            prec = tp/(tp+fp+1e-8)
            rec  = tp/(tp+fn+1e-8)
            f1   = 2*prec*rec/(prec+rec+1e-8)
            print(f"Val loss: {loss:.4f}  Val F1: {f1:.4f}  (P={prec:.4f}, R={rec:.4f})")


Epoch 1:   0%|          | 0/469 [00:00<?, ?it/s]


OutOfMemoryError: CUDA out of memory. Tried to allocate 1.64 GiB. GPU 0 has a total capacity of 15.48 GiB of which 1.59 GiB is free. Process 2857 has 258.00 MiB memory in use. Including non-PyTorch memory, this process has 12.85 GiB memory in use. Of the allocated memory 9.55 GiB is allocated by PyTorch, and 2.96 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)