# LabTOP Training on Google Colab TPU

## Step 1: Mount Google Drive & Setup Directories


In [None]:
from google.colab import drive
drive.mount('/content/drive')

!mkdir -p /content/drive/MyDrive/mimiciv/icu
!mkdir -p /content/drive/MyDrive/mimiciv/hosp


## Step 2: Get MIMIC-IV Data (reuse the Drive copy, or download from PhysioNet)


In [None]:
# Use existing MIMIC-IV data stored in Drive
mimic_icu = "/content/drive/MyDrive/mimiciv/icu"
mimic_hosp = "/content/drive/MyDrive/mimiciv/hosp"

print("Using existing MIMIC-IV data from Drive:")
!ls -lh $mimic_icu
!ls -lh $mimic_hosp


In [None]:
import getpass
import os
import shutil

print("Logging into PhysioNet")
# Credentials are never hardcoded in the notebook. Read them from PHYSIONET_USER /
# PHYSIONET_PASSWORD when set (e.g. exported from Colab secrets); otherwise prompt.
# getpass keeps the password out of the cell's output.
user = os.environ.get("PHYSIONET_USER") or input("PhysioNet username: ")
password = os.environ.get("PHYSIONET_PASSWORD") or getpass.getpass("PhysioNet password: ")

# The wget downloads in this cell authenticate through ~/.netrc entries for both PhysioNet hosts.
netrc_path = "/root/.netrc"
with open(netrc_path, "w") as f:
    # Restrict the file to its owner *before* the password is written, not afterwards.
    os.chmod(netrc_path, 0o600)
    f.write(f"machine physionet.org login {user} password {password}\n")
    f.write(f"machine content.physionet.org login {user} password {password}\n")
print("Authentication configured.")

# MIMIC-IV v2.2 tables to fetch. Each maps local file name -> PhysioNet URL, and the
# URL is simply the module folder (icu/ or hosp/) plus the same file name.
icu_files = {
    name: f"https://physionet.org/files/mimiciv/2.2/icu/{name}"
    for name in (
        "icustays.csv.gz",
        "inputevents.csv.gz",
        "procedureevents.csv.gz",
        "outputevents.csv.gz",
        "d_items.csv.gz",
    )
}

hosp_files = {
    name: f"https://physionet.org/files/mimiciv/2.2/hosp/{name}"
    for name in (
        "admissions.csv.gz",
        "patients.csv.gz",
        "labevents.csv.gz",
        "d_labitems.csv.gz",
        "microbiologyevents.csv.gz",
        "emar.csv.gz",
        "emar_detail.csv.gz",
    )
}

# Stage downloads on the VM's local disk; they are copied to Drive at the end of this cell.
local_icu = "/content/mimiciv/icu"
local_hosp = "/content/mimiciv/hosp"
for _folder in (local_icu, local_hosp):
    os.makedirs(_folder, exist_ok=True)
print("Local download folders prepared.")

def download_files(file_dict, out_dir):
    """Download every file in ``file_dict`` into ``out_dir`` using wget.

    Args:
        file_dict: mapping of local file name -> source URL.
        out_dir: directory that receives the files (must already exist).

    Raises:
        subprocess.CalledProcessError: if wget exits non-zero (bad credentials,
            network error, missing file, ...). A failed download is therefore never
            reported as completed.
        FileNotFoundError: if the ``wget`` executable is not available.
    """
    import subprocess  # local import: the cell's import lines live outside this function

    for name, url in file_dict.items():
        print(f"\nDownloading {name} ...")
        target = os.path.join(out_dir, name)
        # Pass an argument list (no shell), so file names and URLs are never
        # shell-interpreted. -c asks wget to continue a partially downloaded file.
        subprocess.run(
            ["wget", "--progress=bar:force", "-c", "-O", target, url],
            check=True,
        )
        print(f"   ✔ Completed: {name}")

# Fetch both MIMIC-IV modules into their local staging folders.
for _label, _files, _dest in (
    ("ICU", icu_files, local_icu),
    ("HOSP", hosp_files, local_hosp),
):
    print(f"\n=== Downloading {_label} files ===")
    download_files(_files, _dest)

print("\nAll downloads completed successfully!")

# Persist the downloaded tables on Google Drive so later sessions can reuse them.
drive_root = "/content/drive/MyDrive/mimiciv"
drive_icu = f"{drive_root}/icu"
drive_hosp = f"{drive_root}/hosp"
for _dst in (drive_icu, drive_hosp):
    os.makedirs(_dst, exist_ok=True)

print("\nCopying results to Google Drive... (this may take 1–3 minutes)")
for _src, _dst in ((local_icu, drive_icu), (local_hosp, drive_hosp)):
    # dirs_exist_ok lets a re-run copy into the already-existing Drive folders.
    shutil.copytree(_src, _dst, dirs_exist_ok=True)

print("\nFiles copied to Google Drive at:")
for _dst in (drive_icu, drive_hosp):
    print(f"   {_dst}")

   ✔ Completed: inputevents.csv.gz

Downloading procedureevents.csv.gz ...
   ✔ Completed: procedureevents.csv.gz

Downloading outputevents.csv.gz ...
   ✔ Completed: outputevents.csv.gz

Downloading d_items.csv.gz ...
   ✔ Completed: d_items.csv.gz

=== Downloading HOSP files ===

Downloading admissions.csv.gz ...
   ✔ Completed: admissions.csv.gz

Downloading patients.csv.gz ...
   ✔ Completed: patients.csv.gz

Downloading labevents.csv.gz ...
   ✔ Completed: labevents.csv.gz

Downloading d_labitems.csv.gz ...
   ✔ Completed: d_labitems.csv.gz

Downloading microbiologyevents.csv.gz ...
   ✔ Completed: microbiologyevents.csv.gz

Downloading emar.csv.gz ...
   ✔ Completed: emar.csv.gz

Downloading emar_detail.csv.gz ...
   ✔ Completed: emar_detail.csv.gz

All downloads completed successfully!

Copying results to Google Drive... (this may take 1–3 minutes)

Files copied to Google Drive at:
   /content/drive/MyDrive/mimiciv/icu
   /content/drive/MyDrive/mimiciv/hosp


In [None]:
!ls -lh /content/mimiciv/hosp/labevents.csv.gz

### Sanity check: confirm the downloaded files are present on Drive

In [None]:
import os

# Core MIMIC-IV tables whose presence (and non-trivial size) is verified on Drive,
# given as paths relative to `base`.
required_files = [
    "icu/icustays.csv.gz",
    "icu/inputevents.csv.gz",
    "icu/procedureevents.csv.gz",
    "icu/outputevents.csv.gz",
    "icu/d_items.csv.gz",
    "hosp/admissions.csv.gz",
    "hosp/patients.csv.gz",
    "hosp/labevents.csv.gz",
    "hosp/d_labitems.csv.gz",
]

base = "/content/drive/MyDrive/mimiciv"

for rel_path in required_files:
    full_path = f"{base}/{rel_path}"
    if not os.path.exists(full_path):
        print(f"MISSING: {rel_path}")
        continue
    mb = os.path.getsize(full_path) / (1024 * 1024)
    # Files of 0.01 MB (~10 KB) or less are flagged as empty / failed downloads.
    verdict = "GOOD" if mb > 0.01 else "EMPTY"
    print(f"{verdict} {rel_path}: {mb:.2f} MB")

In [None]:
import os

# Progress check for the inputevents download, compared against its ~2500 MB full size.
path = "/content/drive/MyDrive/mimiciv/icu/inputevents.csv.gz"
if os.path.exists(path):
    size_mb = os.path.getsize(path) / 1024 ** 2
    print(f"Current size: {size_mb:.1f} MB / ~2500 MB")

## Step 3: Install PyTorch XLA for TPU Support


In [None]:
!pip install torch torch_xla[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html

!pip install accelerate transformers hydra-core omegaconf pandas numpy scipy scikit-learn tqdm datasets tokenizers safetensors huggingface-hub

## Step 4: Verify TPU is Available


In [None]:
import os

import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr

device = xm.xla_device()
print(f"TPU Device: {device}")

# Use the new API to get world size, falling back to the global device count if it
# fails on this runtime. `except Exception` (not a bare `except:`) lets
# KeyboardInterrupt / SystemExit propagate.
try:
    print(f"Number of TPU cores: {xr.world_size()}")
except Exception:
    print(f"Number of devices: {xr.global_runtime_device_count()}")

# Test tensor on TPU
test_tensor = torch.randn(3, 3).to(device)
print(f"Test tensor created on device: {test_tensor.device}")
print("XLA device is ready!")

# Check whether the XLA backend really is a TPU. The device string is a generic XLA
# name (typically "xla:0") on every backend, so matching "CPU" in it never fires.
# Ask the PJRT runtime for the hardware type instead.
pjrt_device = os.environ.get("PJRT_DEVICE", "Not set")
hw_type = xr.device_type()
print(f"\nPJRT_DEVICE: {pjrt_device}")
print(f"XLA hardware type: {hw_type}")
if hw_type != "TPU":
    print(f"WARNING: Running on {hw_type}, not TPU!")
    print("   Make sure Runtime → Change runtime type → TPU")
else:
    print("TPU is active!")

## Step 5: Clone Repository

In [None]:
%cd /content
!rm -rf labtop-reproduction
!git clone https://github.com/kiotov2/labtop-reproduction.git
%cd labtop-reproduction

## Step 6: Slice MIMIC-IV Data


In [None]:
%cd /content/labtop-reproduction

!python scripts/slice_mimic.py \
    --source /content/drive/MyDrive/mimiciv \
    --dest ./data_small \
    --n_stays 200


## Step 7: Create TPU-Optimized Configs


In [None]:
import os

# Write the configs to an absolute location. The preprocess/train commands run
# `src/scripts/*.py` from /content/labtop-reproduction/labtop, and the config is
# inspected there, so the YAMLs belong under labtop/src/config. Writing relative to
# the cwd (repo root after the slicing cell) would put them in the wrong tree.
CONFIG_ROOT = "/content/labtop-reproduction/labtop/src/config"
data_cfg_dir = os.path.join(CONFIG_ROOT, "data")
train_cfg_dir = os.path.join(CONFIG_ROOT, "train")
os.makedirs(data_cfg_dir, exist_ok=True)
os.makedirs(train_cfg_dir, exist_ok=True)

# Data config
with open(os.path.join(data_cfg_dir, "mimiciv_small.yaml"), "w") as f:
    f.write("""defaults:
  - mimiciv

raw_data_path: /content/labtop-reproduction/data_small
min_los: 1
debug_table_sample_ratio: 1.0
""")

# Train config. `lr` is written as 1.0e-4: YAML 1.1 loaders (e.g. plain PyYAML) read a
# mantissa without a decimal point such as `1e-4` as a string, not a float.
with open(os.path.join(train_cfg_dir, "train_small_tpu.yaml"), "w") as f:
    f.write("""defaults:
  - train_base

epochs: 5
batch_size: 2              # Reduced from 8 to prevent OOM
gradient_accumulation_steps: 8  # Increased to keep effective batch size = 16
use_wandb: false
patience: 1
max_seq_len: 512
lr: 1.0e-4
""")

print("Configs created for TPU training (MEMORY-OPTIMIZED)")
print(f"   - Written under: {CONFIG_ROOT}")
print("   - Batch size: 2 per core (reduced for memory)")
print("   - Gradient accumulation: 8 steps")
print("   - Effective batch size: 16 (2 × 8)")

In [None]:
# Decompress all the sliced data files
!cd /content/labtop-reproduction/data_small/icu && for f in *.csv.gz; do gunzip -k "$f"; done
!cd /content/labtop-reproduction/data_small/hosp && for f in *.csv.gz; do gunzip -k "$f"; done

print("Files decompressed!")

# Verify
!ls -lh /content/labtop-reproduction/data_small/icu/icustays.csv
!ls -lh /content/labtop-reproduction/data_small/hosp/

In [None]:
import os

# Setup authentication
netrc_path = "/root/.netrc"
with open(netrc_path, "w") as f:
    f.write(f"machine physionet.org login kiotov2 password <password>\n")
    f.write(f"machine content.physionet.org login kiotov2 password <password>\n")
os.chmod(netrc_path, 0o600)

!wget -O /content/labtop-reproduction/data_small/icu/d_items.csv.gz https://physionet.org/files/mimiciv/2.2/icu/d_items.csv.gz

# Decompress it
!gunzip -k /content/labtop-reproduction/data_small/icu/d_items.csv.gz

# Verify
!ls -lh /content/labtop-reproduction/data_small/icu/d_items.csv

print("d_items downloaded and decompressed!")

## Step 8: Preprocess Data


In [None]:
# Decompress d_items specifically
!gunzip -k /content/labtop-reproduction/data_small/icu/d_items.csv.gz

# Verify it's there
!ls -lh /content/labtop-reproduction/data_small/icu/d_items.csv

print("d_items.csv decompressed!")

In [None]:
%cd /content/labtop-reproduction/labtop

!python src/scripts/preprocess.py \
    data=mimiciv_small \
    data_path=/content/labtop-reproduction/data_small \
    data.use_tables="[labevents,inputevents,procedureevents,outputevents]" \
    max_seq_len=512


In [None]:
# Combined check: are the sliced raw tables and the preprocessed datasets in place?
import os

print("DATA STATUS CHECK\n" + "="*50)

# Check 1: both sliced module folders must exist.
sliced_exists = all(
    os.path.exists(f"/content/labtop-reproduction/data_small/{module}")
    for module in ("icu", "hosp")
)
print(f"{'GOOD' if sliced_exists else 'BAD'} Sliced data (data_small/)")

# Check 2: preprocessing counts as done when some folder under data/mimiciv contains
# an entry whose name includes "dataset".
preprocessed_exists = os.path.exists("/content/labtop-reproduction/data/mimiciv")
if preprocessed_exists:
    subdirs = os.listdir("/content/labtop-reproduction/data/mimiciv")
    run_dirs = (f"/content/labtop-reproduction/data/mimiciv/{d}" for d in subdirs)
    has_datasets = any(
        any("dataset" in entry for entry in os.listdir(run_dir))
        for run_dir in run_dirs
        if os.path.isdir(run_dir)
    )
    preprocessed_exists = has_datasets

print(f"{'GOOD' if preprocessed_exists else 'BAD'} Preprocessed data (data/mimiciv/)")

print("\n" + "="*50)
print("VERDICT:")
pending_steps = [
    step
    for step, done in (("SLICING", sliced_exists), ("PREPROCESSING", preprocessed_exists))
    if not done
]
for step in pending_steps:
    print(f"Need to run: {step}")
if not pending_steps:
    print("ALL DATA READY - Skip to training!")

In [None]:
# Check what's actually in the file
!cat src/config/train/train_small_tpu.yaml

In [None]:
# Patch trainer.py to disable accelerate.save_state on TPU.
trainer_path = "/content/labtop-reproduction/labtop/src/core/models/trainer.py"
save_call = "self.accelerator.save_state(self.model_dir)"
replacement = "print('Skipping save_state on TPU (patched).')"

with open(trainer_path, "r") as f:
    code = f.read()

n_calls = code.count(save_call)
if n_calls:
    # Replace the call expression in place, so the surrounding indentation (and hence
    # the file's syntax) is untouched.
    patched = code.replace(save_call, replacement)
    with open(trainer_path, "w") as f:
        f.write(patched)
    print(f"Patched trainer.py to skip save_state on TPU ({n_calls} call(s) replaced)")
elif replacement in code:
    # Re-running the cell: the previous run already swapped the call out.
    print("trainer.py already patched; nothing to do")
else:
    # Do not claim success when the expected call is missing; the trainer code changed.
    print("WARNING: save_state call not found in trainer.py - patch NOT applied. "
          "Check whether the trainer code changed.")


## Step 9: Train on TPU


In [None]:
%cd /content/labtop-reproduction/labtop

# Clear Hydra cache
!rm -rf .hydra
!rm -rf outputs


print("Cache cleared!")

# Now run with explicit override
import os
data_folder = os.listdir("/content/labtop-reproduction/data/mimiciv")[0]

!python src/scripts/train.py \
    data=mimiciv_small \
    train=train_small_tpu \
    max_seq_len=512 \
    train.epochs=5 \
    data_path=/content/labtop-reproduction/data/mimiciv/{data_folder}

## Step 10: Evaluate — inspect the saved model artifacts


In [None]:
ls -R ./trained_models/mimiciv_labevents_inputevents_procedureevents_outputevents


## Monitor XLA Metrics (notebook kernel process only)


In [None]:
from torch_xla.debug import metrics as met

# Print PyTorch/XLA's debug metrics report for this process.
# NOTE(review): this covers only the notebook kernel's own XLA usage. The
# `!python src/scripts/train.py` run executes in a separate process, so its metrics
# will not show up here.
print(met.metrics_report())
