In [6]:
import numpy as np
import torch
from torch import nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [7]:
def get_views(npz_path):
    data = np.load(npz_path)
    vol = data["slices"]

    print(f"Volume shape: {vol.shape}")
    Z, Y, X = vol.shape

    mid_z = Z // 2
    axial = [
        vol[mid_z],
        vol[max(mid_z - 2, 0)],
        vol[min(mid_z + 2, Z - 1)],
    ]

    mid_y = Y // 2
    coronal = [
        vol[:, mid_y, :],
        vol[:, max(mid_y - 2, 0), :],
        vol[:, min(mid_y + 2, Y - 1), :],
    ]

    mid_x = X // 2
    sagittal = [
        vol[:, :, mid_x],
        vol[:, :, max(mid_x - 2, 0)],
        vol[:, :, min(mid_x + 2, X - 1)],
    ]

    views = axial + coronal + sagittal

    views = [v.astype(np.float32) for v in views]

    return views

get_views("patches/subset0/1.3.6.1.4.1.14519.5.2.1.6279.6001.105756658031515062000744821260.npz")

Volume shape: (64, 64, 64)


[array([[0.11333334, 0.18266666, 0.45666668, ..., 0.79      , 0.79333335,
         0.796     ],
        [0.126     , 0.246     , 0.5513333 , ..., 0.794     , 0.7926667 ,
         0.79933333],
        [0.14933333, 0.286     , 0.58      , ..., 0.796     , 0.79      ,
         0.7873333 ],
        ...,
        [0.196     , 0.18      , 0.15933333, ..., 0.762     , 0.74866664,
         0.74866664],
        [0.17866667, 0.168     , 0.14066666, ..., 0.75133336, 0.754     ,
         0.76533335],
        [0.18466666, 0.16533333, 0.14666666, ..., 0.75266665, 0.7646667 ,
         0.768     ]], shape=(64, 64), dtype=float32),
 array([[0.48533332, 0.51      , 0.59866667, ..., 0.772     , 0.78466666,
         0.7826667 ],
        [0.498     , 0.5453333 , 0.64133334, ..., 0.7733333 , 0.7786667 ,
         0.786     ],
        [0.49466667, 0.546     , 0.676     , ..., 0.76066667, 0.77      ,
         0.7773333 ],
        ...,
        [0.198     , 0.19      , 0.204     , ..., 0.766     , 0.76266664,
   

In [8]:
from torchvision.models import resnet50

class FPRModel(nn.Module):
    def __init__(self, num_views=9, out_dim=1):
        super().__init__()

        base = resnet50(weights=None)

        base.conv1 = nn.Conv2d(
            1, 64, kernel_size=7, stride=2, padding=3, bias=False
        )

        self.backbone = nn.Sequential(*list(base.children())[:-1])
        self.feature_dim = 2048

        self.num_views = num_views

        self.classifier = nn.Sequential(
            nn.Linear(self.feature_dim * num_views, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3),
            nn.Linear(512, out_dim)
        )

    def forward(self, views):
        if isinstance(views, torch.Tensor):
            B, N, C, H, W = views.shape
            assert N == self.num_views
            views = [views[:, i, :, :, :] for i in range(N)]

        features = []
        for v in views:
            f = self.backbone(v)
            f = f.squeeze(-1).squeeze(-1)
            features.append(f)

        feats = torch.cat(features, dim=1)

        out = self.classifier(feats)

        return out
        

In [12]:
from torch.utils.data import Dataset, DataLoader

class MultiViewDataset(Dataset):
    def __init__(self, npz_files, labels):
        self.files = npz_files
        self.labels = labels
        
    def __len__(self):
        return len(self.files)
    
    def __getitem__(self, idx):
        data = np.load(self.files[idx])
        vol = data["slices"]

        views = get_views(self.files[idx])  

        views = [torch.tensor(v, dtype=torch.float32).unsqueeze(0) for v in views]

        views = torch.stack(views, dim=0)

        label = torch.tensor(self.labels[idx], dtype=torch.float32)

        return views, label

In [10]:
from tqdm import tqdm 

def train_multiview(model, train_loader, val_loader, device, epochs=20):
    model = model.to(device)

    criterion = nn.BCEWithLogitsLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    
    best_val_loss = float("inf")

    for epoch in range(epochs):

        model.train()
        train_loss = 0.0
        correct = 0
        total = 0
        
        for views, labels in tqdm(train_loader, desc=f"Epoch {epoch} [Train]"):
            # views: (B, 9, 1, H, W)
            views = views.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()

            # Forward
            outputs = model(views)        # shape (B,1)
            outputs = outputs.squeeze(1)  # shape (B,)

            # Loss
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            train_loss += loss.item() * labels.size(0)

            # Accuracy
            preds = torch.sigmoid(outputs) > 0.5
            correct += (preds == labels.bool()).sum().item()
            total += labels.size(0)

        train_loss /= total
        train_acc = correct / total
        
        model.eval()
        val_loss = 0.0
        correct = 0
        total = 0

        with torch.no_grad():
            for views, labels in tqdm(val_loader, desc=f"Epoch {epoch} [Val]"):
                views = views.to(device)
                labels = labels.to(device)

                outputs = model(views)
                outputs = outputs.squeeze(1)

                loss = criterion(outputs, labels)
                val_loss += loss.item() * labels.size(0)

                preds = torch.sigmoid(outputs) > 0.5
                correct += (preds == labels.bool()).sum().item()
                total += labels.size(0)

        val_loss /= total
        val_acc = correct / total
        
        print(f"""
        Epoch {epoch}:
            Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f}
            Val Loss:   {val_loss:.4f} | Val Acc:   {val_acc:.4f}
        """)

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), "best_model.pth")
            print("Saved best model")

In [11]:
train_dataset = MultiViewDataset(train_files, train_labels)
val_dataset = MultiViewDataset(val_files, val_labels)

train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=4, shuffle=False, num_workers=4)


NameError: name 'train_files' is not defined

In [None]:

model = FPRModel(num_views=9, out_dim=1)
train_multiview(model, train_loader, val_loader, device="cuda")
