In [None]:
import numpy as np
from torch.utils.data import Dataset, DataLoader
import pandas as pd
from torch.utils.tensorboard import SummaryWriter
from addict import Dict
from pathlib import Path
import torch
from stability.models.cnn import CBRNet, cnn_loss
from stability.data import CellDataset
from stability.train import train
import yaml
import json
import os

First, let's construct data loaders. We need to refer to the `split.csv` file, which describes the paths to patches in the train, dev, and test sets.

In [18]:
# Project and data locations come from the environment so that no absolute
# paths are hardcoded; a missing variable raises KeyError right here.
root_dir = Path(os.environ["ROOT_DIR"])
data_dir = Path(os.environ["DATA_DIR"])

# Parse the experiment configuration. The context manager closes the file
# handle as soon as parsing finishes instead of leaving it open for the
# lifetime of the kernel.
with open(root_dir / "conf/tnbc_cnn.yaml", "r") as f:
    opts = Dict(yaml.safe_load(f))

# The split table is read from the location named in the config; the code
# below relies on it having a `path` column and a `split` column.
split_path = data_dir / opts.organization.xy
split = pd.read_csv(split_path)

# Build one dataset and one loader per subset, keyed by subset name.
sp = ("train", "dev", "test")
img_paths = {k: split[split["split"] == k].path.values for k in sp}
ds = {k: CellDataset(img_paths[k], split_path, data_dir) for k in sp}
# NOTE(review): none of the loaders shuffle, including "train" -- confirm
# that a fixed batch order during training is intended.
loaders = {k: DataLoader(ds[k], batch_size=128) for k in sp}

Now, let's set up a directory to hold our feature activations and training logs, and record the configuration we are using alongside the logs.

In [19]:
# Feature output location, taken from the config and resolved under data_dir.
features_dir = data_dir / opts.organization.features_dir

# TensorBoard event files go in a "logs" subfolder of the features directory.
writer = SummaryWriter(log_dir=features_dir / "logs")

# Store the full configuration as a JSON string under the "conf" tag so the
# settings used for this run are kept next to its logs.
writer.add_text(tag="conf", text_string=json.dumps(opts))

Finally, we can train our model. Note that we have 7 input channels, since we are using 7 (grouped) cell types. Other than that, training is exactly like it was for the original cells data.

In [None]:
# Seven input channels: one per (grouped) cell type.
model = CBRNet(p_in=7)
optim = torch.optim.Adam(params=model.parameters(), lr=opts.train.lr)

# Paths handed to `train`, in this order: the features directory, the
# metadata file, and the final model checkpoint -- all under data_dir.
out_paths = [
    data_dir / target
    for target in (
        opts.organization.features_dir,
        opts.organization.metadata,
        "model_final.pt",
    )
]

train(model, optim, loaders, opts, out_paths, writer, cnn_loss)

0/30
