In [1]:
import json

In [2]:
train_datalist = json.load(open("data_volumes/train_datalist.json"))[:939]
valid_datalist = json.load(open("data_volumes/valid_datalist.json"))

In [3]:
import torch
device = torch.device('cuda')

  import pynvml  # type: ignore[import]


In [4]:
for i in train_datalist:
    i['embeddings'] = f"/home/jupyter/datasphere/project/data_volumes/dataset/embeddings/{i['name'].replace('.nii.gz', '.pt')}"
    
for i in valid_datalist:
    i['embeddings'] = f"/home/jupyter/datasphere/project/data_volumes/dataset/embeddings_val/{i['name'].replace('.nii.gz', '.pt')}"


In [5]:
import torch
from monai.networks.nets import SwinUNETR

device = "cuda" if torch.cuda.is_available() else "cpu"

# создаём модель
model = SwinUNETR(
    img_size=(96, 96, 96),
    in_channels=2,
    out_channels=43,
    feature_size=48,
    use_checkpoint=True,
).to(device)

# загрузка
state_dict = torch.load('best_model.pth', map_location=device)
model.load_state_dict(state_dict)
model.eval()

Using `TRANSFORMERS_CACHE` is deprecated and will be removed in v5 of Transformers. Use `HF_HOME` instead.
2025-10-01 04:06:22.797086: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


SwinUNETR(
  (swinViT): SwinTransformer(
    (patch_embed): PatchEmbed(
      (proj): Conv3d(2, 48, kernel_size=(2, 2, 2), stride=(2, 2, 2))
    )
    (pos_drop): Dropout(p=0.0, inplace=False)
    (layers1): ModuleList(
      (0): BasicLayer(
        (blocks): ModuleList(
          (0-1): 2 x SwinTransformerBlock(
            (norm1): LayerNorm((48,), eps=1e-05, elementwise_affine=True)
            (attn): WindowAttention(
              (qkv): Linear(in_features=48, out_features=144, bias=True)
              (attn_drop): Dropout(p=0.0, inplace=False)
              (proj): Linear(in_features=48, out_features=48, bias=True)
              (proj_drop): Dropout(p=0.0, inplace=False)
              (softmax): Softmax(dim=-1)
            )
            (drop_path): Identity()
            (norm2): LayerNorm((48,), eps=1e-05, elementwise_affine=True)
            (mlp): MLPBlock(
              (linear1): Linear(in_features=48, out_features=192, bias=True)
              (linear2): Linear(in_feature

In [6]:
## import os
import torch, os
from monai.networks.nets import SwinUNETR
from monai.transforms import (
    Compose, LoadImaged, EnsureChannelFirstd, Orientationd, 
    Spacingd, ScaleIntensityRanged, EnsureTyped
)
from monai.data import Dataset, DataLoader

# ---------------------
# Настройки
# ---------------------
device = "cuda" if torch.cuda.is_available() else "cpu"
save_dir = "/home/jupyter/datasphere/project/data_volumes/dataset/embeddings"
os.makedirs(save_dir, exist_ok=True)

roi_size = (96, 96, 96)
overlap = 0.25
TARGET_SIZE = (512, 512, 512)

# ---------------------
# Модель (только swinViT часть нужна)
# ---------------------
model = SwinUNETR(
    img_size=roi_size,
    in_channels=2,
    out_channels=43,
    feature_size=48,
    use_checkpoint=True
).to(device)
model.load_state_dict(torch.load("best_model.pth"))
model.eval()

# ---------------------
# Трансформы
# ---------------------

class AddGlobalResized(Transform):
    """
    Создаёт уменьшенную копию всего скана и кладёт её под новый ключ.
    """
    def __init__(self, source_key="image", target_key="global", size=96):
        self.source_key = source_key
        self.target_key = target_key
        self.size = size

    def __call__(self, data):
        d = dict(data)
        img = d[self.source_key]
        global_resized = F.interpolate(
            img.unsqueeze(0), size=(self.size, self.size, self.size),
            mode="trilinear", align_corners=False
        ).squeeze(0)  # [C, size, size, size]
        d[self.target_key] = global_resized
        return d


class AppendGlobalChannel(Transform):
    """
    Добавляет глобальный канал (d["global"]) к каждому кропу d["image"].
    Работает и если RandCrop вернул список словарей.
    """
    def __init__(self, image_key="image", global_key="global"):
        self.image_key = image_key
        self.global_key = global_key

    def _add_channel(self, d):
        d = dict(d)
        patch_img = d[self.image_key]
        global_resized = d[self.global_key]
        # конкат по каналам
        d[self.image_key] = torch.cat([patch_img, global_resized], dim=0)
        return d

    def __call__(self, data):
        if isinstance(data, list):
            return [self._add_channel(d) for d in data]
        return self._add_channel(data)

transforms = Compose([
    LoadImaged(keys=["image"], ensure_channel_first=True),
    Orientationd(keys=["image"], axcodes="RAS"),
    Spacingd(keys=["image"], pixdim=(1.0,1.0,1.0), mode="bilinear"),
    ScaleIntensityRanged(keys=["image"], a_min=-1000, a_max=400, b_min=0.0, b_max=1.0, clip=True),
    ResizeWithPadOrCropd(keys=["image"], spatial_size=TARGET_SIZE),
    AddGlobalResized(),
    EnsureTyped(keys=["image"])
])

# ---------------------
# Датасет
# ---------------------
train_ds = Dataset(data=train_datalist, transform=transforms)
train_loader = DataLoader(train_ds, batch_size=1, shuffle=False)

# ---------------------
# Кастомный sliding window
# ---------------------
def custom_sliding_embeddings_encoders(model, volume, global_, roi_size=(96,96,96), overlap=0.25, device="cuda"):
    _, _, D, H, W = volume.shape
    d, h, w = roi_size

    stride_d = int(d * (1 - overlap))
    stride_h = int(h * (1 - overlap))
    stride_w = int(w * (1 - overlap))

    embeddings = []

    with torch.no_grad():
        for z in range(0, D, stride_d):
            for y in range(0, H, stride_h):
                for x in range(0, W, stride_w):
                    zz, yy, xx = min(z, D - d), min(y, H - h), min(x, W - w)
                    patch = volume[:, :, zz:zz+d, yy:yy+h, xx:xx+w]
                    patch = torch.cat([patch, global_], dim=1).to(device)

                    # 1) прогон через swinViT
                    hidden_states = model.swinViT(patch)

                    # 2) энкодеры
                    enc0 = model.encoder1(patch)             # low-level
                    enc1 = model.encoder2(hidden_states[0])
                    enc2 = model.encoder3(hidden_states[1])
                    enc3 = model.encoder4(hidden_states[2])
                    enc4 = model.encoder10(hidden_states[4])

                    # 3) spatial pooling
                    pooled = [
                        enc0.mean(dim=[2,3,4]),
                        enc1.mean(dim=[2,3,4]),
                        enc2.mean(dim=[2,3,4]),
                        enc3.mean(dim=[2,3,4]),
                        enc4.mean(dim=[2,3,4]),
                    ]

                    concat = torch.cat(pooled, dim=-1)   # [B, feat_dim]
                    embeddings.append(concat.cpu())

                    del patch, hidden_states, enc0, enc1, enc2, enc3, enc4, pooled, concat
                    torch.cuda.empty_cache()

    embeddings = torch.cat(embeddings, dim=0)  # [num_patches, feat_dim]
    return embeddings


# ---------------------
# Прогон по датасету
# ---------------------
for i, batch in enumerate(train_loader):
    img = batch["image"].to(device) 
    global_ = batch["global"].to(device) # [1,1,D,H,W]
    emb = custom_sliding_embeddings_encoders(model, img, global_, roi_size=roi_size, overlap=overlap, device=device)

    fname = os.path.basename(batch["name"][0])
    save_path = os.path.join(save_dir, fname.replace('.nii.gz', ".pt"))

    torch.save(emb, save_path)
    print(f"[{i+1}/{len(train_loader)}] Saved embedding {emb.shape} → {save_path}")


NameError: name 'Transform' is not defined

In [None]:
a

In [5]:

from transformers import BartTokenizer, BartForConditionalGeneration

# токенайзер и модель
tokenizer = BartTokenizer.from_pretrained("Mahalingam/DistilBart-Med-Summary")
bart = BartForConditionalGeneration.from_pretrained("Mahalingam/DistilBart-Med-Summary")



In [6]:
import numpy as np
from monai.transforms import (
    Compose, LoadImaged, Spacingd, Orientationd, ScaleIntensityRanged,
    CropForegroundd, RandFlipd, RandRotate90d, RandGaussianNoised,
    RandAdjustContrastd, RandShiftIntensityd, RandCoarseDropoutd,
    EnsureTyped, ToTensord, RandAffined, ResizeWithPadOrCropd, RandCropByPosNegLabeld
)
from monai.transforms import Transform

# --- кастом: загрузка сохранённых эмбеддингов ---
class LoadEmbeddingD(Transform):
    def __call__(self, data):
        if "embeddings" in data:
            path = data["embeddings"]
            if path.endswith(".pt") or path.endswith(".pth"):
                emb = torch.load(path, map_location="cpu")
            elif path.endswith(".npy"):
                emb = torch.from_numpy(np.load(path))
            else:
                raise ValueError(f"Unsupported embedding format: {path}")
            data["embeddings"] = emb.float()
        return data

# --- кастом трансформ для labels ---
class CastLabelsToFloatD(Transform):
    def __call__(self, data):
        if "labels" in data:
            data["labels"] = np.asarray(data["labels"]).astype(np.float32)
        return data
# --- кастом трансформ для текста ---
class TokenizeReportD(Transform):
    def __init__(self, tokenizer, max_length=256):
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __call__(self, data):
        text = data.get("report", "")

        # report может быть dict (finding+impression) или str
        if isinstance(text, dict):
            combined = text.get("finding", "") + " " + text.get("impression", "")
        else:
            combined = str(text)

        tokens = self.tokenizer(
            combined,
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt"
        )

        # для BART labels нужны (копия input_ids, паддинги заменяем на -100, чтобы не учитывались в лоссе)
        labels = tokens["input_ids"].clone()
        labels[labels == self.tokenizer.pad_token_id] = -100

        data["input_ids"] = tokens["input_ids"].squeeze(0)         # [seq_len]
        data["attention_mask"] = tokens["attention_mask"].squeeze(0)  # [seq_len]
        data["bart_labels"] = labels.squeeze(0)                         # [seq_len]
        return data


# ----------------- Параметры -------------------
TARGET_SPACING = (1.0, 1.0, 1.0)   # mm
TARGET_SIZE = (512, 512, 512)      # для валида и инференса
PATCH_SIZE = (96, 96, 96)       # патчи для тренировки
HU_MIN, HU_MAX = -1000.0, 400.0    # окно легких



# --- трансформы для multimodal (целый скан для эмбеддингов) ---
multimodal_transforms = Compose([
    TokenizeReportD(tokenizer=tokenizer, max_length=512),
    CastLabelsToFloatD(),
    LoadImaged(keys=["image", "mask"], ensure_channel_first=True),
    Orientationd(keys=["image", "mask"], axcodes="RAS"),
    Spacingd(keys=["image", "mask"], pixdim=TARGET_SPACING, mode=("bilinear", "nearest")),
    CropForegroundd(keys=["image", "mask"], source_key="image"),
    ScaleIntensityRanged(keys=["image"], a_min=HU_MIN, a_max=HU_MAX, b_min=0.0, b_max=1.0, clip=True),
    ResizeWithPadOrCropd(keys=["image", "mask"], spatial_size=TARGET_SIZE),
    EnsureTyped(keys=["image", "mask"]),
    ToTensord(keys=["image", "mask", "labels"]),
])

multimodal_embed_transforms = Compose([
    LoadEmbeddingD(),                      # грузим эмбеддинги из файла
    CastLabelsToFloatD(),                  # метки в float32
    TokenizeReportD(tokenizer=tokenizer, max_length=512),  # токенизация текста
    EnsureTyped(keys=["embeddings", "labels"]),
    ToTensord(keys=["embeddings", "labels", "input_ids", "attention_mask", "bart_labels"]),
])


2025-10-01 06:42:58.896066: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
monai.transforms.croppad.dictionary CropForegroundd.__init__:allow_smaller: Current default value of argument `allow_smaller=True` has been deprecated since version 1.2. It will be changed to `allow_smaller=False` in version 1.5.


In [7]:
import torch
from torch.nn.utils.rnn import pad_sequence


def collate_fn(batch):
    """
    batch: список элементов от __getitem__ датасета
           каждый элемент — dict с:
             - embeddings: [E]
             - labels: [num_classes]
             - input_ids: [seq_len]
             - attention_mask: [seq_len]
             - bart_labels: [seq_len]
    """

    # собираем эмбеддинги и метки
    embeddings = [sample["embeddings"] for sample in batch]
    # паддинг до максимальной длины в этом батче
    embeddings = pad_sequence(embeddings, batch_first=True)  # [B, max_len, 1152]
    labels = torch.stack([sample["labels"] for sample in batch]).float()   # [B, num_classes]

    # паддинг токенов по max длине
    input_ids = torch.nn.utils.rnn.pad_sequence(
        [sample["input_ids"] for sample in batch],
        batch_first=True,
        padding_value=0
    )
    attention_mask = torch.nn.utils.rnn.pad_sequence(
        [sample["attention_mask"] for sample in batch],
        batch_first=True,
        padding_value=0
    )
    bart_labels = torch.nn.utils.rnn.pad_sequence(
        [sample["bart_labels"] for sample in batch],
        batch_first=True,
        padding_value=-100  # у BART паддинги должны быть -100
    )

    return {
        "embeddings": embeddings,        # [B, E]
        "labels": labels,                # [B, num_classes]
        "input_ids": input_ids,          # [B, L]
        "attention_mask": attention_mask,# [B, L]
        "bart_labels": bart_labels       # [B, L]
    }


In [8]:
from monai.data import Dataset
from monai.data.dataset import CacheDataset

train_ds = CacheDataset(data=train_datalist, transform=multimodal_embed_transforms, num_workers=8)
val_ds = CacheDataset(data=valid_datalist, transform=multimodal_embed_transforms, num_workers=8)


Loading dataset: 100%|██████████| 939/939 [00:13<00:00, 68.99it/s]
Loading dataset: 100%|██████████| 50/50 [00:00<00:00, 67.25it/s]


In [9]:
from torch.utils.data import DataLoader

train_loader = DataLoader(train_ds, batch_size=64, shuffle=True, num_workers=4, pin_memory=True, collate_fn=collate_fn)
val_loader = DataLoader(val_ds, batch_size=50, shuffle=False, num_workers=2, collate_fn=collate_fn)

In [14]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import BartForConditionalGeneration
from transformers.modeling_outputs import BaseModelOutput


# === Graph Attention Layer для CLS ===
class SimpleGATLayer(nn.Module):
    def __init__(self, in_dim, out_dim, dropout=0.1, alpha=0.2):
        super().__init__()
        self.W = nn.Linear(in_dim, out_dim, bias=False)
        self.a = nn.Linear(2 * out_dim, 1, bias=False)
        self.dropout = nn.Dropout(dropout)
        self.leakyrelu = nn.LeakyReLU(alpha)

    def forward(self, h):
        B, N, D = h.shape
        Wh = self.W(h)  # [B, N, out_dim]
        Wh_repeat1 = Wh.unsqueeze(2).repeat(1, 1, N, 1)  # [B, N, N, D]
        Wh_repeat2 = Wh.unsqueeze(1).repeat(1, N, 1, 1)  # [B, N, N, D]
        e = self.leakyrelu(self.a(torch.cat([Wh_repeat1, Wh_repeat2], dim=-1))).squeeze(-1)  # [B, N, N]

        attn = F.softmax(e, dim=-1)
        attn = self.dropout(attn)
        h_prime = torch.bmm(attn, Wh)  # [B, N, out_dim]
        return h_prime


# === Cross-Attention Bridge c CLS и GEN потоками ===
class CrossAttentionBridge(nn.Module):
    def __init__(self, embed_dim, bart_dim, num_heads=8, num_cls=18, num_gen=8, dropout=0.1):
        super().__init__()
        self.proj = nn.Linear(embed_dim, bart_dim)
        self.proj_norm = nn.LayerNorm(bart_dim)

        self.cls_queries = nn.Parameter(torch.randn(1, num_cls, bart_dim))
        self.gen_queries = nn.Parameter(torch.randn(1, num_gen, bart_dim))

        self.cross_attn = nn.MultiheadAttention(
            embed_dim=bart_dim, num_heads=num_heads, dropout=dropout, batch_first=True
        )
        self.norm = nn.LayerNorm(bart_dim)
        self.dropout = nn.Dropout(dropout)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=bart_dim, nhead=8, batch_first=True, dropout=dropout
        )
        self.cls_encoder = nn.TransformerEncoder(encoder_layer, num_layers=2)
        self.gen_encoder = nn.TransformerEncoder(encoder_layer, num_layers=2)

        self.gat = SimpleGATLayer(bart_dim, bart_dim, dropout=dropout)

    def forward(self, patch_embeddings):
        B = patch_embeddings.size(0)
        patches = self.proj_norm(self.proj(patch_embeddings))

        # CLS поток
        cls_queries = self.cls_queries.expand(B, -1, -1)
        cls_attended, _ = self.cross_attn(query=cls_queries, key=patches, value=patches)
        cls_attended = self.norm(cls_attended + cls_queries)
        cls_attended = self.dropout(cls_attended)
        cls_attended = self.cls_encoder(cls_attended)
        cls_attended = self.gat(cls_attended)

        # GEN поток
        gen_queries = self.gen_queries.expand(B, -1, -1)
        gen_attended, _ = self.cross_attn(query=gen_queries, key=patches, value=patches)
        gen_attended = self.norm(gen_attended + gen_queries)
        gen_attended = self.dropout(gen_attended)
        gen_attended = self.gen_encoder(gen_attended)

        return cls_attended, gen_attended


# === MoE Classifier ===
class MoEClassifier(nn.Module):
    def __init__(self, bart_dim, num_classes=18, num_experts=4, hidden_dim=512, dropout=0.1):
        super().__init__()
        self.num_classes = num_classes
        self.num_experts = num_experts

        self.experts = nn.ModuleList([
            nn.Sequential(
                nn.Linear(bart_dim, hidden_dim),
                nn.GELU(),
                nn.Dropout(dropout),
                nn.Linear(hidden_dim, hidden_dim),
                nn.GELU(),
                nn.LayerNorm(hidden_dim),
                nn.Dropout(dropout),
                nn.Linear(hidden_dim, num_classes)
            )
            for _ in range(num_experts)
        ])

        self.gate = nn.Sequential(
            nn.Linear(bart_dim, num_experts),
            nn.Softmax(dim=-1)
        )

    def forward(self, cls_reprs):
        pooled = cls_reprs.mean(dim=1)
        gate_weights = self.gate(pooled)

        expert_outputs = []
        for expert in self.experts:
            out = expert(pooled)
            expert_outputs.append(out.unsqueeze(2))

        expert_outputs = torch.cat(expert_outputs, dim=2)  # [B, num_classes, num_experts]
        gate_weights = gate_weights.unsqueeze(1)           # [B, 1, num_experts]

        logits = torch.bmm(expert_outputs, gate_weights.transpose(1, 2)).squeeze(-1)
        return logits


# === Uncertainty Weighting ===
class UncertaintyWeighting(nn.Module):
    def __init__(self):
        super().__init__()
        self.log_vars = nn.Parameter(torch.zeros(2))

    def forward(self, cls_loss, bart_loss):
        cls_loss_weighted = torch.exp(-self.log_vars[0]) * cls_loss + self.log_vars[0]
        bart_loss_weighted = torch.exp(-self.log_vars[1]) * bart_loss + self.log_vars[1]
        return cls_loss_weighted + bart_loss_weighted


# === Coverage Loss ===
def compute_coverage_loss(cross_attentions, eps=1e-8):
    attns = [att.mean(dim=1) for att in cross_attentions]  # [batch, tgt_len, src_len]
    attn = torch.stack(attns).mean(dim=0)                  # [batch, tgt_len, src_len]

    coverage = torch.zeros_like(attn[:, 0, :])
    cov_loss = 0.0
    for t in range(attn.size(1)):
        step_attn = attn[:, t, :]
        cov_loss += torch.sum(torch.min(step_attn, coverage), dim=-1).mean()
        coverage = coverage + step_attn

    return cov_loss / (attn.size(1) + eps)


# === Основная модель ===
import torch
import torch.nn as nn
from transformers import BartForConditionalGeneration
from transformers.modeling_outputs import BaseModelOutput


class MultiModalMonsterUltimate(nn.Module):
    def __init__(self, bart_ckpt="facebook/bart-base",
                 embed_dim=1152, num_classes=18, dropout=0.1,
                 num_experts=4, num_gen_queries=8,
                 lambda_bart=0.2):
        super().__init__()

        self.bart = BartForConditionalGeneration.from_pretrained(bart_ckpt)

        self.bridge = CrossAttentionBridge(
            embed_dim=embed_dim,
            bart_dim=self.bart.config.d_model,
            num_heads=8,
            num_cls=num_classes,
            num_gen=num_gen_queries,
            dropout=dropout
        )

        # оставляем MoE, но можно заменить на простой Linear для дебага
        self.classifier = MoEClassifier(
            bart_dim=self.bart.config.d_model,
            num_classes=num_classes,
            num_experts=num_experts,
            hidden_dim=512,
            dropout=dropout
        )

        self.num_classes = num_classes
        self.lambda_bart = lambda_bart  # вес генерации

    def forward(self, embeddings, labels=None, bart_labels=None,
                classification_loss=None, use_soft_prompt=False, tokenizer=None, topk=3):
        out = {}

        cls_repr, gen_repr = self.bridge(embeddings)

        # --- Классификация ---
        logits_cls = self.classifier(cls_repr)
        out["logits_cls"] = logits_cls

        if labels is not None and classification_loss is not None:
            cls_loss = classification_loss(logits_cls, labels.float())
        else:
            cls_loss = torch.tensor(0.0, device=embeddings.device)

        # --- Генерация ---
        bart_loss = torch.tensor(0.0, device=embeddings.device)
        coverage_loss = torch.tensor(0.0, device=embeddings.device)
        bridge_out = gen_repr

        if bart_labels is not None:
            if use_soft_prompt and tokenizer is not None:
                probs = torch.sigmoid(logits_cls)
                topk_idx = torch.topk(probs, k=min(topk, self.num_classes), dim=-1).indices[0].tolist()
                prompt_text = "Findings: " + ", ".join([f"pathology_{i}" for i in topk_idx]) + ". Report:"
                prompt_tokens = tokenizer.encode(prompt_text, return_tensors="pt").to(embeddings.device)
                bart_labels = torch.cat([prompt_tokens, bart_labels], dim=1)

            bart_out = self.bart(
                encoder_outputs=BaseModelOutput(last_hidden_state=bridge_out),
                labels=bart_labels,
                output_attentions=True,
                return_dict=True
            )
            bart_loss = bart_out.loss
            coverage_loss = compute_coverage_loss(bart_out.cross_attentions)
            out["bart_out"] = bart_out

        # --- Лоссы ---
        out["cls_loss"] = cls_loss
        out["bart_loss"] = bart_loss
        out["coverage_loss"] = coverage_loss

        # фиксированный баланс
        out["loss"] = cls_loss + self.lambda_bart * (bart_loss + coverage_loss)

        return out



In [15]:
import torch
import torch.nn.functional as F
from sklearn.metrics import f1_score
from nltk.translate.bleu_score import sentence_bleu

# BCEWithLogitsLoss с весами для мульти-лейбл классификации
cls_loss = torch.nn.BCEWithLogitsLoss()

def step_fn(model, batch, optimizer=None, device="cuda", val=False, tokenizer=None):
    """
    Один шаг тренировки или валидации для MultiModalMonsterUltimate.
    Возвращает loss и метрики.
    """
    model.train(optimizer is not None)
    batch = {k: v.to(device) if torch.is_tensor(v) else v for k, v in batch.items()}

    # прямой проход
    outputs = model(
        embeddings=batch['embeddings'],
        labels=batch['labels'],
        bart_labels=batch['bart_labels'],
        classification_loss=cls_loss,
        use_soft_prompt=True if val else False,  # во время валидации можно подсунуть prompt
    )

    loss = outputs["loss"]

    # шаг оптимизации
    if optimizer:
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

    # метрики
    metrics = {}
    if val:
        # === Классификация ===
        if "logits_cls" in outputs and "labels" in batch:
            preds = (torch.sigmoid(outputs["logits_cls"]) > 0.5).int().cpu().numpy()
            labels = batch["labels"].cpu().numpy().astype(int)

            metrics["f1_micro"] = f1_score(labels, preds, average="micro")
            metrics["f1_macro"] = f1_score(labels, preds, average="macro")

        # === Генерация текста ===
        if "bart_out" in outputs and "bart_labels" in batch and tokenizer is not None:
            # gold labels
            refs = batch["bart_labels"].detach().cpu().tolist()
            refs = [[tok for tok in seq if tok is not None and tok != -100] for seq in refs]

            # предсказания из логитов
            preds = outputs["bart_out"].logits.argmax(-1).detach().cpu().tolist()

            # декодирование
            refs_text = tokenizer.batch_decode(refs, skip_special_tokens=True)
            preds_text = tokenizer.batch_decode(preds, skip_special_tokens=True)

            bleu_scores = [
                sentence_bleu([r.split()], h.split()) if len(r) > 0 and len(h) > 0 else 0.0
                for r, h in zip(refs_text, preds_text)
            ]
            metrics["bleu"] = sum(bleu_scores) / max(1, len(bleu_scores))

        # === Coverage Loss ===
        if "coverage_loss" in outputs:
            metrics["coverage_loss"] = outputs["coverage_loss"].item()

        # Отдельно лоссы
        metrics["cls_loss"] = outputs["cls_loss"].item()
        metrics["bart_loss"] = outputs["bart_loss"].item()

    return loss.item(), metrics


In [16]:
import warnings
warnings.filterwarnings("ignore", category=UserWarning, module="nltk")

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
import os
from tqdm import tqdm

# === Инициализация ===
device = "cuda" if torch.cuda.is_available() else "cpu"
model = MultiModalMonsterUltimate().to(device)
# model.load_state_dict(torch.load('best_monster.pth')["model_state"])
cls_loss = nn.BCEWithLogitsLoss()

# разделяем параметры (BART vs остальные)
bart_params, other_params = [], []
for name, param in model.named_parameters():
    if "bart" in name:
        bart_params.append(param)
    else:
        other_params.append(param)

optimizer = AdamW([
    {"params": other_params, "lr": 1e-4, "weight_decay": 0.01},
    {"params": bart_params, "lr": 1e-6, "weight_decay": 0.01}
])

num_epochs = 2000
scheduler = CosineAnnealingLR(optimizer, T_max=num_epochs)

save_dir = "checkpoints"
os.makedirs(save_dir, exist_ok=True)

best_val_bleu = -float("inf")  # чем больше BLEU, тем лучше

# === Цикл обучения ===
for epoch in range(num_epochs):
    model.train()
    total_loss, total_f1_micro, total_f1_macro, total_bleu = 0, 0, 0, 0
    n_batches = len(train_loader)

    # --- Тренировка ---
    for batch in tqdm(train_loader, desc=f"Epoch {epoch+1} [train]"):
        loss, metrics = step_fn(model, batch, optimizer, device, val=False, tokenizer=tokenizer)
        total_loss += loss
        total_f1_micro += metrics.get("f1_micro", 0)
        total_f1_macro += metrics.get("f1_macro", 0)
        total_bleu += metrics.get("bleu", 0)

    scheduler.step()

    avg_train_loss = total_loss / n_batches
    avg_train_f1_micro = total_f1_micro / n_batches
    avg_train_f1_macro = total_f1_macro / n_batches
    avg_train_bleu = total_bleu / n_batches

    # --- Валидация ---
    avg_val_loss, avg_val_f1_micro, avg_val_f1_macro, avg_val_bleu = 0, 0, 0, 0
    if epoch % 10 == 0:
        model.eval()
        with torch.no_grad():
            for batch in tqdm(val_loader, desc=f"Epoch {epoch+1} [val]"):
                loss, metrics = step_fn(model, batch, None, device, val=True, tokenizer=tokenizer)
                avg_val_loss += loss
                avg_val_f1_micro += metrics.get("f1_micro", 0)
                avg_val_f1_macro += metrics.get("f1_macro", 0)
                avg_val_bleu += metrics.get("bleu", 0)

        avg_val_loss /= len(val_loader)
        avg_val_f1_micro /= len(val_loader)
        avg_val_f1_macro /= len(val_loader)
        avg_val_bleu /= len(val_loader)

        # --- лог ---
        lrs = [pg["lr"] for pg in optimizer.param_groups]
        lr_str = " | ".join([f"{lr:.6e}" for lr in lrs])

        print(f"Epoch {epoch+1}/{num_epochs} || "
              f"Train Loss: {avg_train_loss:.4f} | F1_micro: {avg_train_f1_micro:.4f} | "
              f"F1_macro: {avg_train_f1_macro:.4f} | BLEU: {avg_train_bleu:.4f} || "
              f"Val Loss: {avg_val_loss:.4f} | F1_micro: {avg_val_f1_micro:.4f} | "
              f"F1_macro: {avg_val_f1_macro:.4f} | BLEU: {avg_val_bleu:.4f} || "
              f"LRs: {lr_str}")

        # --- чекпоинт ---
        if avg_val_bleu > best_val_bleu:
            best_val_bleu = avg_val_bleu
            save_path = os.path.join(save_dir, f"epoch{epoch+1}_valBLEU{avg_val_bleu:.4f}.pt")
            torch.save({
                "epoch": epoch + 1,
                "model_state": model.state_dict(),
                "optimizer_state": optimizer.state_dict(),
                "scheduler_state": scheduler.state_dict(),
                "val_loss": avg_val_loss,
                "val_bleu": avg_val_bleu,
            }, save_path)
            print(f"💾 Сохранён лучший чекпоинт: {save_path}")


Epoch 1 [train]: 100%|██████████| 15/15 [00:23<00:00,  1.54s/it]
Epoch 1 [val]:   0%|          | 0/1 [00:00<?, ?it/s]F-score is ill-defined and being set to 0.0 in labels with no true nor predicted samples. Use `zero_division` parameter to control this behavior.
Epoch 1 [val]: 100%|██████████| 1/1 [00:01<00:00,  1.90s/it]


Epoch 1/2000 || Train Loss: 2.3796 | F1_micro: 0.0000 | F1_macro: 0.0000 | BLEU: 0.0000 || Val Loss: 1.9772 | F1_micro: 0.2613 | F1_macro: 0.0380 | BLEU: 0.0000 || LRs: 9.999994e-05 | 9.999994e-07
💾 Сохранён лучший чекпоинт: checkpoints/epoch1_valBLEU0.0000.pt


Epoch 2 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.52s/it]
Epoch 3 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.51s/it]
Epoch 4 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.52s/it]
Epoch 5 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.52s/it]
Epoch 6 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.53s/it]
Epoch 7 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.51s/it]
Epoch 8 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.51s/it]
Epoch 9 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.51s/it]
Epoch 10 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.51s/it]
Epoch 11 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.51s/it]
Epoch 11 [val]:   0%|          | 0/1 [00:00<?, ?it/s]F-score is ill-defined and being set to 0.0 in labels with no true nor predicted samples. Use `zero_division` parameter to control this behavior.
Epoch 11 [val]: 100%|██████████| 1/1 [00:01<00:00,  1.90s/it]


Epoch 11/2000 || Train Loss: 2.0285 | F1_micro: 0.0000 | F1_macro: 0.0000 | BLEU: 0.0000 || Val Loss: 1.8832 | F1_micro: 0.0000 | F1_macro: 0.0000 | BLEU: 0.0000 || LRs: 9.999254e-05 | 9.999254e-07


Epoch 12 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.51s/it]
Epoch 13 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.52s/it]
Epoch 14 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.52s/it]
Epoch 15 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.52s/it]
Epoch 16 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.51s/it]
Epoch 17 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.51s/it]
Epoch 18 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.51s/it]
Epoch 19 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.51s/it]
Epoch 20 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.51s/it]
Epoch 21 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.51s/it]
Epoch 21 [val]:   0%|          | 0/1 [00:00<?, ?it/s]F-score is ill-defined and being set to 0.0 in labels with no true nor predicted samples. Use `zero_division` parameter to control this behavior.
Epoch 21 [val]: 100%|██████████| 1/1 [00:01<00:00,  1.89s/it]


Epoch 21/2000 || Train Loss: 1.9976 | F1_micro: 0.0000 | F1_macro: 0.0000 | BLEU: 0.0000 || Val Loss: 1.8578 | F1_micro: 0.0000 | F1_macro: 0.0000 | BLEU: 0.0000 || LRs: 9.997280e-05 | 9.997280e-07


Epoch 22 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.51s/it]
Epoch 23 [train]: 100%|██████████| 15/15 [00:23<00:00,  1.54s/it]
Epoch 24 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.51s/it]
Epoch 25 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.52s/it]
Epoch 26 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.51s/it]
Epoch 27 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.52s/it]
Epoch 28 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.50s/it]
Epoch 29 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.51s/it]
Epoch 30 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.51s/it]
Epoch 31 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.51s/it]
Epoch 31 [val]:   0%|          | 0/1 [00:00<?, ?it/s]F-score is ill-defined and being set to 0.0 in labels with no true nor predicted samples. Use `zero_division` parameter to control this behavior.
Epoch 31 [val]: 100%|██████████| 1/1 [00:01<00:00,  1.93s/it]


Epoch 31/2000 || Train Loss: 1.9741 | F1_micro: 0.0000 | F1_macro: 0.0000 | BLEU: 0.0000 || Val Loss: 1.8374 | F1_micro: 0.0000 | F1_macro: 0.0000 | BLEU: 0.0000 || LRs: 9.994073e-05 | 9.994073e-07
💾 Сохранён лучший чекпоинт: checkpoints/epoch31_valBLEU0.0000.pt


Epoch 32 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.52s/it]
Epoch 33 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.52s/it]
Epoch 34 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.51s/it]
Epoch 35 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.51s/it]
Epoch 36 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.52s/it]
Epoch 37 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.51s/it]
Epoch 38 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.51s/it]
Epoch 39 [train]: 100%|██████████| 15/15 [00:23<00:00,  1.55s/it]
Epoch 40 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.51s/it]
Epoch 41 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.51s/it]
Epoch 41 [val]:   0%|          | 0/1 [00:00<?, ?it/s]F-score is ill-defined and being set to 0.0 in labels with no true nor predicted samples. Use `zero_division` parameter to control this behavior.
Epoch 41 [val]: 100%|██████████| 1/1 [00:01<00:00,  1.91s/it]


Epoch 41/2000 || Train Loss: 1.9424 | F1_micro: 0.0000 | F1_macro: 0.0000 | BLEU: 0.0000 || Val Loss: 1.7539 | F1_micro: 0.0000 | F1_macro: 0.0000 | BLEU: 0.0000 || LRs: 9.989634e-05 | 9.989634e-07


Epoch 42 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.52s/it]
Epoch 43 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.51s/it]
Epoch 44 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.51s/it]
Epoch 45 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.52s/it]
Epoch 46 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.51s/it]
Epoch 47 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.51s/it]
Epoch 48 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.51s/it]
Epoch 49 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.51s/it]
Epoch 50 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.53s/it]
Epoch 51 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.51s/it]
Epoch 51 [val]:   0%|          | 0/1 [00:00<?, ?it/s]F-score is ill-defined and being set to 0.0 in labels with no true nor predicted samples. Use `zero_division` parameter to control this behavior.
Epoch 51 [val]: 100%|██████████| 1/1 [00:01<00:00,  1.91s/it]


Epoch 51/2000 || Train Loss: 1.5918 | F1_micro: 0.0000 | F1_macro: 0.0000 | BLEU: 0.0000 || Val Loss: 1.3124 | F1_micro: 0.1667 | F1_macro: 0.0292 | BLEU: 0.0271 || LRs: 9.983964e-05 | 9.983964e-07
💾 Сохранён лучший чекпоинт: checkpoints/epoch51_valBLEU0.0271.pt


Epoch 52 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.52s/it]
Epoch 53 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.53s/it]
Epoch 54 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.52s/it]
Epoch 55 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.52s/it]
Epoch 56 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.52s/it]
Epoch 57 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.52s/it]
Epoch 58 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.51s/it]
Epoch 59 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.51s/it]
Epoch 60 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.52s/it]
Epoch 61 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.53s/it]
Epoch 61 [val]:   0%|          | 0/1 [00:00<?, ?it/s]F-score is ill-defined and being set to 0.0 in labels with no true nor predicted samples. Use `zero_division` parameter to control this behavior.
Epoch 61 [val]: 100%|██████████| 1/1 [00:01<00:00,  1.90s/it]

Epoch 61/2000 || Train Loss: 1.3980 | F1_micro: 0.0000 | F1_macro: 0.0000 | BLEU: 0.0000 || Val Loss: 1.1659 | F1_micro: 0.1657 | F1_macro: 0.0287 | BLEU: 0.0573 || LRs: 9.977065e-05 | 9.977065e-07





💾 Сохранён лучший чекпоинт: checkpoints/epoch61_valBLEU0.0573.pt


Epoch 62 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.51s/it]
Epoch 63 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.51s/it]
Epoch 64 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.51s/it]
Epoch 65 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.51s/it]
Epoch 66 [train]: 100%|██████████| 15/15 [00:23<00:00,  1.54s/it]
Epoch 67 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.52s/it]
Epoch 68 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.51s/it]
Epoch 69 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.51s/it]
Epoch 70 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.52s/it]
Epoch 71 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.52s/it]
Epoch 71 [val]:   0%|          | 0/1 [00:00<?, ?it/s]F-score is ill-defined and being set to 0.0 in labels with no true nor predicted samples. Use `zero_division` parameter to control this behavior.
Epoch 71 [val]: 100%|██████████| 1/1 [00:01<00:00,  1.91s/it]


Epoch 71/2000 || Train Loss: 1.2876 | F1_micro: 0.0000 | F1_macro: 0.0000 | BLEU: 0.0000 || Val Loss: 1.0585 | F1_micro: 0.1724 | F1_macro: 0.0431 | BLEU: 0.0918 || LRs: 9.968937e-05 | 9.968937e-07
💾 Сохранён лучший чекпоинт: checkpoints/epoch71_valBLEU0.0918.pt


Epoch 72 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.51s/it]
Epoch 73 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.50s/it]
Epoch 74 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.50s/it]
Epoch 75 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.51s/it]
Epoch 76 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.51s/it]
Epoch 77 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.51s/it]
Epoch 78 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.51s/it]
Epoch 79 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.51s/it]
Epoch 80 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.50s/it]
Epoch 81 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.51s/it]
Epoch 81 [val]:   0%|          | 0/1 [00:00<?, ?it/s]F-score is ill-defined and being set to 0.0 in labels with no true nor predicted samples. Use `zero_division` parameter to control this behavior.
Epoch 81 [val]: 100%|██████████| 1/1 [00:01<00:00,  1.92s/it]


Epoch 81/2000 || Train Loss: 1.2147 | F1_micro: 0.0000 | F1_macro: 0.0000 | BLEU: 0.0000 || Val Loss: 1.0025 | F1_micro: 0.2069 | F1_macro: 0.0511 | BLEU: 0.1132 || LRs: 9.959583e-05 | 9.959583e-07
💾 Сохранён лучший чекпоинт: checkpoints/epoch81_valBLEU0.1132.pt


Epoch 82 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.51s/it]
Epoch 83 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.51s/it]
Epoch 84 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.52s/it]
Epoch 85 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.51s/it]
Epoch 86 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.51s/it]
Epoch 87 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.52s/it]
Epoch 88 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.51s/it]
Epoch 89 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.50s/it]
Epoch 90 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.51s/it]
Epoch 91 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.52s/it]
Epoch 91 [val]:   0%|          | 0/1 [00:00<?, ?it/s]F-score is ill-defined and being set to 0.0 in labels with no true nor predicted samples. Use `zero_division` parameter to control this behavior.
Epoch 91 [val]: 100%|██████████| 1/1 [00:01<00:00,  1.95s/it]


Epoch 91/2000 || Train Loss: 1.1462 | F1_micro: 0.0000 | F1_macro: 0.0000 | BLEU: 0.0000 || Val Loss: 0.9888 | F1_micro: 0.3143 | F1_macro: 0.1371 | BLEU: 0.1768 || LRs: 9.949006e-05 | 9.949006e-07
💾 Сохранён лучший чекпоинт: checkpoints/epoch91_valBLEU0.1768.pt


Epoch 92 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.51s/it]
Epoch 93 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.51s/it]
Epoch 94 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.51s/it]
Epoch 95 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.52s/it]
Epoch 96 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.51s/it]
Epoch 97 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.51s/it]
Epoch 98 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.52s/it]
Epoch 99 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.53s/it]
Epoch 100 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.51s/it]
Epoch 101 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.52s/it]
Epoch 101 [val]:   0%|          | 0/1 [00:00<?, ?it/s]F-score is ill-defined and being set to 0.0 in labels with no true nor predicted samples. Use `zero_division` parameter to control this behavior.
Epoch 101 [val]: 100%|██████████| 1/1 [00:01<00:00,  1.92s/it]


Epoch 101/2000 || Train Loss: 1.0899 | F1_micro: 0.0000 | F1_macro: 0.0000 | BLEU: 0.0000 || Val Loss: 0.9710 | F1_micro: 0.3544 | F1_macro: 0.2112 | BLEU: 0.1936 || LRs: 9.937207e-05 | 9.937207e-07
💾 Сохранён лучший чекпоинт: checkpoints/epoch101_valBLEU0.1936.pt


Epoch 102 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.51s/it]
Epoch 103 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.51s/it]
Epoch 104 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.53s/it]
Epoch 105 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.51s/it]
Epoch 106 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.51s/it]
Epoch 107 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.51s/it]
Epoch 108 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.52s/it]
Epoch 109 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.51s/it]
Epoch 110 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.52s/it]
Epoch 111 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.51s/it]
Epoch 111 [val]:   0%|          | 0/1 [00:00<?, ?it/s]F-score is ill-defined and being set to 0.0 in labels with no true nor predicted samples. Use `zero_division` parameter to control this behavior.
Epoch 111 [val]: 100%|██████████| 1/1 [00:01<00:00,  1.93s/it]


Epoch 111/2000 || Train Loss: 1.0099 | F1_micro: 0.0000 | F1_macro: 0.0000 | BLEU: 0.0000 || Val Loss: 0.9644 | F1_micro: 0.3320 | F1_macro: 0.1597 | BLEU: 0.2434 || LRs: 9.924190e-05 | 9.924190e-07
💾 Сохранён лучший чекпоинт: checkpoints/epoch111_valBLEU0.2434.pt


IOStream.flush timed out
Epoch 112 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.51s/it]
Epoch 113 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.51s/it]
Epoch 114 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.51s/it]
Epoch 115 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.51s/it]
Epoch 116 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.51s/it]
Epoch 117 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.52s/it]
Epoch 118 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.52s/it]
Epoch 119 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.52s/it]
Epoch 120 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.52s/it]
Epoch 121 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.52s/it]
Epoch 121 [val]: 100%|██████████| 1/1 [00:01<00:00,  1.92s/it]


Epoch 121/2000 || Train Loss: 0.9382 | F1_micro: 0.0000 | F1_macro: 0.0000 | BLEU: 0.0000 || Val Loss: 0.9890 | F1_micro: 0.4066 | F1_macro: 0.2967 | BLEU: 0.2570 || LRs: 9.909959e-05 | 9.909959e-07
💾 Сохранён лучший чекпоинт: checkpoints/epoch121_valBLEU0.2570.pt


Epoch 122 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.51s/it]
Epoch 123 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.51s/it]
Epoch 124 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.51s/it]
Epoch 125 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.51s/it]
Epoch 126 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.53s/it]
Epoch 127 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.51s/it]
Epoch 128 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.51s/it]
Epoch 129 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.52s/it]
Epoch 130 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.51s/it]
Epoch 131 [train]: 100%|██████████| 15/15 [00:23<00:00,  1.54s/it]
Epoch 131 [val]: 100%|██████████| 1/1 [00:01<00:00,  1.91s/it]


Epoch 131/2000 || Train Loss: 0.8400 | F1_micro: 0.0000 | F1_macro: 0.0000 | BLEU: 0.0000 || Val Loss: 0.9882 | F1_micro: 0.3789 | F1_macro: 0.3046 | BLEU: 0.2643 || LRs: 9.894515e-05 | 9.894515e-07
💾 Сохранён лучший чекпоинт: checkpoints/epoch131_valBLEU0.2643.pt


Epoch 132 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.52s/it]
Epoch 133 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.52s/it]
Epoch 134 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.52s/it]
Epoch 135 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.51s/it]
Epoch 136 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.52s/it]
Epoch 137 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.52s/it]
Epoch 138 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.51s/it]
Epoch 139 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.51s/it]
Epoch 140 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.51s/it]
Epoch 141 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.52s/it]
Epoch 141 [val]: 100%|██████████| 1/1 [00:01<00:00,  1.91s/it]


Epoch 141/2000 || Train Loss: 0.7573 | F1_micro: 0.0000 | F1_macro: 0.0000 | BLEU: 0.0000 || Val Loss: 0.9843 | F1_micro: 0.3569 | F1_macro: 0.2224 | BLEU: 0.2911 || LRs: 9.877864e-05 | 9.877864e-07
💾 Сохранён лучший чекпоинт: checkpoints/epoch141_valBLEU0.2911.pt


Epoch 142 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.52s/it]
Epoch 143 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.51s/it]
Epoch 144 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.51s/it]
Epoch 145 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.52s/it]
Epoch 146 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.52s/it]
Epoch 147 [train]: 100%|██████████| 15/15 [00:23<00:00,  1.54s/it]
Epoch 148 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.52s/it]
Epoch 149 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.51s/it]
Epoch 150 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.51s/it]
Epoch 151 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.52s/it]
Epoch 151 [val]: 100%|██████████| 1/1 [00:01<00:00,  1.92s/it]


Epoch 151/2000 || Train Loss: 0.6888 | F1_micro: 0.0000 | F1_macro: 0.0000 | BLEU: 0.0000 || Val Loss: 1.0467 | F1_micro: 0.3310 | F1_macro: 0.1941 | BLEU: 0.3033 || LRs: 9.860010e-05 | 9.860010e-07


Epoch 155 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.52s/it]
Epoch 156 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.52s/it]
Epoch 157 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.52s/it]
Epoch 158 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.52s/it]
Epoch 160 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.53s/it]
Epoch 161 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.52s/it]
Epoch 161 [val]: 100%|██████████| 1/1 [00:01<00:00,  1.90s/it]


Epoch 161/2000 || Train Loss: 0.6485 | F1_micro: 0.0000 | F1_macro: 0.0000 | BLEU: 0.0000 || Val Loss: 1.0401 | F1_micro: 0.3259 | F1_macro: 0.1782 | BLEU: 0.3084 || LRs: 9.840957e-05 | 9.840957e-07
💾 Сохранён лучший чекпоинт: checkpoints/epoch161_valBLEU0.3084.pt


Epoch 162 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.52s/it]
Epoch 163 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.52s/it]
Epoch 166 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.53s/it]
Epoch 167 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.53s/it]
Epoch 168 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.52s/it]
Epoch 169 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.52s/it]
Epoch 170 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.52s/it]
Epoch 171 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.52s/it]
Epoch 171 [val]: 100%|██████████| 1/1 [00:01<00:00,  1.90s/it]


Epoch 171/2000 || Train Loss: 0.6173 | F1_micro: 0.0000 | F1_macro: 0.0000 | BLEU: 0.0000 || Val Loss: 1.0444 | F1_micro: 0.3650 | F1_macro: 0.1733 | BLEU: 0.2809 || LRs: 9.820709e-05 | 9.820709e-07


Epoch 172 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.51s/it]
Epoch 173 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.52s/it]
Epoch 174 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.52s/it]
Epoch 175 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.53s/it]
Epoch 176 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.51s/it]
Epoch 177 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.52s/it]
Epoch 178 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.52s/it]
Epoch 179 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.52s/it]
Epoch 180 [train]: 100%|██████████| 15/15 [00:23<00:00,  1.55s/it]
Epoch 181 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.52s/it]
Epoch 181 [val]: 100%|██████████| 1/1 [00:01<00:00,  1.90s/it]


Epoch 181/2000 || Train Loss: 0.5943 | F1_micro: 0.0000 | F1_macro: 0.0000 | BLEU: 0.0000 || Val Loss: 1.1008 | F1_micro: 0.3213 | F1_macro: 0.1812 | BLEU: 0.3225 || LRs: 9.799271e-05 | 9.799271e-07
💾 Сохранён лучший чекпоинт: checkpoints/epoch181_valBLEU0.3225.pt


Epoch 182 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.51s/it]
Epoch 183 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.52s/it]
Epoch 184 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.52s/it]
Epoch 185 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.52s/it]
Epoch 186 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.53s/it]
Epoch 187 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.52s/it]
Epoch 188 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.52s/it]
Epoch 189 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.52s/it]
Epoch 190 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.52s/it]
Epoch 191 [train]: 100%|██████████| 15/15 [00:23<00:00,  1.54s/it]
Epoch 191 [val]: 100%|██████████| 1/1 [00:01<00:00,  1.91s/it]


Epoch 191/2000 || Train Loss: 0.5722 | F1_micro: 0.0000 | F1_macro: 0.0000 | BLEU: 0.0000 || Val Loss: 1.1278 | F1_micro: 0.3728 | F1_macro: 0.2163 | BLEU: 0.3250 || LRs: 9.776650e-05 | 9.776650e-07
💾 Сохранён лучший чекпоинт: checkpoints/epoch191_valBLEU0.3250.pt


Epoch 192 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.52s/it]
Epoch 193 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.52s/it]
Epoch 194 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.52s/it]
Epoch 195 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.53s/it]
Epoch 201 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.52s/it]
Epoch 201 [val]: 100%|██████████| 1/1 [00:01<00:00,  1.93s/it]


Epoch 201/2000 || Train Loss: 0.5590 | F1_micro: 0.0000 | F1_macro: 0.0000 | BLEU: 0.0000 || Val Loss: 1.2161 | F1_micro: 0.3427 | F1_macro: 0.2009 | BLEU: 0.3166 || LRs: 9.752850e-05 | 9.752850e-07


Epoch 202 [train]: 100%|██████████| 15/15 [00:23<00:00,  1.55s/it]
Epoch 203 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.52s/it]
Epoch 204 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.52s/it]
Epoch 205 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.52s/it]
Epoch 206 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.53s/it]
Epoch 207 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.52s/it]
Epoch 208 [train]: 100%|██████████| 15/15 [00:23<00:00,  1.54s/it]
Epoch 209 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.52s/it]
Epoch 210 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.52s/it]
Epoch 211 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.52s/it]
Epoch 211 [val]:   0%|          | 0/1 [00:00<?, ?it/s]F-score is ill-defined and being set to 0.0 in labels with no true nor predicted samples. Use `zero_division` parameter to control this behavior.
Epoch 211 [val]: 100%|██████████| 1/1 [00:01<00:00,  1.95s/it]


Epoch 211/2000 || Train Loss: 0.5534 | F1_micro: 0.0000 | F1_macro: 0.0000 | BLEU: 0.0000 || Val Loss: 1.1495 | F1_micro: 0.3468 | F1_macro: 0.1860 | BLEU: 0.3238 || LRs: 9.727877e-05 | 9.727877e-07


Epoch 212 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.52s/it]
Epoch 213 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.52s/it]
Epoch 214 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.52s/it]
Epoch 215 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.52s/it]
Epoch 216 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.52s/it]
Epoch 217 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.52s/it]
Epoch 218 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.52s/it]
Epoch 219 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.52s/it]
Epoch 220 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.52s/it]
Epoch 221 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.52s/it]
Epoch 221 [val]:   0%|          | 0/1 [00:00<?, ?it/s]F-score is ill-defined and being set to 0.0 in labels with no true nor predicted samples. Use `zero_division` parameter to control this behavior.
Epoch 221 [val]: 100%|██████████| 1/1 [00:01<00:00,  1.93s/it]


Epoch 221/2000 || Train Loss: 0.5473 | F1_micro: 0.0000 | F1_macro: 0.0000 | BLEU: 0.0000 || Val Loss: 1.1830 | F1_micro: 0.3309 | F1_macro: 0.1949 | BLEU: 0.2873 || LRs: 9.701738e-05 | 9.701738e-07


Epoch 222 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.52s/it]
Epoch 223 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.52s/it]
Epoch 224 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.52s/it]
Epoch 225 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.52s/it]
Epoch 226 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.52s/it]
Epoch 227 [train]: 100%|██████████| 15/15 [00:22<00:00,  1.53s/it]
Epoch 228 [train]:  87%|████████▋ | 13/15 [00:21<00:03,  1.63s/it]


KeyboardInterrupt: 

In [19]:
model.load_state_dict(torch.load("checkpoints/epoch191_valBLEU0.3250.pt")["model_state"])

<All keys matched successfully>

In [21]:
import torch
from transformers.modeling_outputs import BaseModelOutput

def debug_one_batch_monster(model, dataloader, tokenizer, device="cuda", idx=0, thr=0.4):
    model.eval()
    batch = next(iter(dataloader))  # берём один батч

    # --- На девайс ---
    embeddings = batch["embeddings"].to(device)          # [B, N, D]
    labels = batch["labels"].to(device)                  # [B, num_classes]
    bart_labels = batch["bart_labels"].to(device)        # [B, T]

    with torch.no_grad():
        outputs = model(
            embeddings,
            labels=labels,
            bart_labels=bart_labels,
            classification_loss=torch.nn.BCEWithLogitsLoss()
        )

    # === КЛАССИФИКАЦИЯ ===
    true_labels = labels[idx].cpu().numpy()
    pred_logits = outputs["logits_cls"][idx].cpu()
    pred_probs = torch.sigmoid(pred_logits).numpy()
    pred_classes = (pred_probs > thr).astype(int)

    # === ГЕНЕРАЦИЯ ТЕКСТА ===
    # правильный отчёт (референс)
    refs = bart_labels[idx].detach().cpu().tolist()
    refs = [tok for tok in refs if tok is not None and tok != -100]
    ref_text = tokenizer.decode(refs, skip_special_tokens=True)

    # генерация (beam search)
    bridge_out = outputs.get("bart_out", None)
    if bridge_out is None:
        # если в out не сохранили, достанем заново
        _, gen_repr = model.bridge(embeddings)
        enc_out = BaseModelOutput(last_hidden_state=gen_repr)
    else:
        enc_out = BaseModelOutput(last_hidden_state=outputs["bart_out"].encoder_last_hidden_state)

    generated_ids = model.bart.generate(
        encoder_outputs=enc_out,
        max_length=1024,
        num_beams=4
    )
    hyp_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)

    # === ВЫВОД ===
    print("\n=== Классификация ===")
    print("Истинные классы:", true_labels)
    print("Предсказанные :", pred_classes)
    print("Вероятности   :", pred_probs.round(3))

    print("\n=== Тексты ===")
    print("GT :", ref_text)
    print("PR :", hyp_text)

    return {
        "true_labels": true_labels,
        "pred_classes": pred_classes,
        "pred_probs": pred_probs,
        "ref_text": ref_text,
        "hyp_text": hyp_text
    }


In [23]:
debug_one_batch_monster(model, val_loader, tokenizer, device="cuda")


=== Классификация ===
Истинные классы: [0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]
Предсказанные : [0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0]
Вероятности   : [0.003 0.002 0.    0.008 0.001 0.    0.001 0.065 0.001 0.1   0.04  0.53
 0.    0.001 0.    0.004 0.    0.028]

=== Тексты ===
GT : Trachea, both main bronchi are open. Mediastinal main vascular structures, heart contour, size are normal. Thoracic aorta diameter is normal. Pericardial effusion-thickening was not observed. Thoracic esophageal calibration was normal and no significant tumoral wall thickening was detected. No enlarged lymph nodes in prevascular, pre-paratracheal, subcarinal or bilateral hilar-axillary pathological dimensions were detected. When examined in the lung parenchyma window; A few millimetric nonspecific nodules and mild recessions are observed in the upper lobe and lower lobe of the right lung. Aeration of both lung parenchyma is normal and no infiltrative lesion is detected in the lung parenchyma. P

{'true_labels': array([0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.,
        0.], dtype=float32),
 'pred_classes': array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0]),
 'pred_probs': array([3.0590175e-03, 2.1879892e-03, 1.3973824e-04, 7.6949741e-03,
        5.9485430e-04, 1.1335634e-04, 5.6751986e-04, 6.5188758e-02,
        9.0039906e-04, 1.0023847e-01, 3.9702304e-02, 5.2968460e-01,
        4.8123969e-04, 1.0277594e-03, 3.2878856e-04, 3.5577000e-03,
        3.3481632e-04, 2.8099500e-02], dtype=float32),
 'ref_text': 'Trachea, both main bronchi are open. Mediastinal main vascular structures, heart contour, size are normal. Thoracic aorta diameter is normal. Pericardial effusion-thickening was not observed. Thoracic esophageal calibration was normal and no significant tumoral wall thickening was detected. No enlarged lymph nodes in prevascular, pre-paratracheal, subcarinal or bilateral hilar-axillary pathological dimensions were detected. When examined in 