# LWNet Reproduction - Google Colab

This notebook allows you to reproduce LWNet (The Little W-Net That Could) results using Google Colab's free GPU.

**Paper:** [The Little W-Net That Could: State-of-the-Art Retinal Vessel Segmentation with Minimalistic Models](https://arxiv.org/abs/2009.01907)

**Original Repository:** [agaldran/lwnet](https://github.com/agaldran/lwnet)

## Setup Instructions

1. **Enable GPU**: Runtime → Change runtime type → Hardware accelerator → GPU
2. Run cells sequentially from top to bottom
3. Training will use Colab's GPU (Tesla T4 or similar)


## 1. Check GPU Availability

In [11]:
import torch
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"CUDA version: {torch.version.cuda}")
else:
    print("⚠️ GPU not available! Please enable GPU in Runtime → Change runtime type")

PyTorch version: 2.9.0+cu126
CUDA available: True
GPU: Tesla T4
CUDA version: 12.6
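
The later cells pass `--device cuda:0` to the LWNet scripts. As a minimal sketch, the device string could be derived from the check above instead of being hard-coded. The helper name `pick_device` is hypothetical, and using `cpu` as the fallback value is an assumption about what the scripts accept.

```python
def pick_device(cuda_available: bool, index: int = 0) -> str:
    """Return a device string such as 'cuda:0', or 'cpu' when no GPU is present."""
    return f"cuda:{index}" if cuda_available else "cpu"


try:
    import torch
    has_cuda = torch.cuda.is_available()
except ImportError:
    # torch is not installed in this interpreter; treat that as "no GPU".
    has_cuda = False

device = pick_device(has_cuda)
print(f"Using device: {device}")
```

The resulting `device` value can then be substituted into the `--device` argument of the training and prediction commands.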


## 2. Clone Your Repository

Replace the repository URL in the clone command below with the URL of your own fork if needed.

In [2]:
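# `!cd` only affects a throwaway subshell, so use %cd to leave the directory before deleting it
%cd /content
!rm -rf /content/lwnet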

In [None]:
import os

# Clone the main repository
if not os.path.exists('lwnet'):
    # TODO: fix code to be compatible with conda, push to personal repo and clone
    !git clone --recurse-submodules https://github.com/agaldran/lwnet.git
    print("✓ Repository cloned successfully")
else:
    print("✓ Repository already exists")

# Navigate to lwnet directory
%cd lwnet

Cloning into 'lwnet'...
remote: Enumerating objects: 1198, done.
remote: Counting objects: 100% (346/346), done.
remote: Compressing objects: 100% (278/278), done.
remote: Total 1198 (delta 100), reused 309 (delta 66), pack-reused 852 (from 1)
Receiving objects: 100% (1198/1198), 22.42 MiB | 25.25 MiB/s, done.
Resolving deltas: 100% (603/603), done.
✓ Repository cloned successfully
/content/lwnet
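
If you clone from a fork, it can help to keep the URL in one place and derive the directory name from it. The sketch below only builds and prints the command; the helper names `repo_dir_from_url` and `clone_command` are hypothetical and not part of the LWNet repository.

```python
import shlex

# Assumption: replace this with your own fork's URL if you have one.
REPO_URL = "https://github.com/agaldran/lwnet.git"


def repo_dir_from_url(url: str) -> str:
    """Directory name git creates by default: last path component without '.git'."""
    name = url.rstrip("/").rsplit("/", 1)[-1]
    return name[:-4] if name.endswith(".git") else name


def clone_command(url: str) -> str:
    """Build the same clone command used in the cell above for a given URL."""
    return f"git clone --recurse-submodules {shlex.quote(url)}"


print(repo_dir_from_url(REPO_URL))
print(clone_command(REPO_URL))
```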


## 3. Install Dependencies

Install Miniconda and create the `lwnet` conda environment from the repository's `environment.txt`.

In [None]:
# Install Conda (Miniconda)
!wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh
!chmod +x Miniconda3-latest-Linux-x86_64.sh
!bash ./Miniconda3-latest-Linux-x86_64.sh -b -f -p /usr/local

# Add Conda to path
# import sys
# sys.path.append('/usr/local/lib/python3.8/site-packages')

!conda tos accept --override-channels --channel https://repo.anaconda.com/pkgs/main
!conda tos accept --override-channels --channel https://repo.anaconda.com/pkgs/r

# Create environment and install dependencies based on LWNet guide
# Note: Adjust python version if specific requirement exists in guide
# !conda env remove --name lwnet -y
!conda create -n lwnet --file environment.txt -y
# !source activate lwnet && conda install pytorch=1.8.0 torchvision=0.9.0 -c pytorch -y

print("✓ Conda installed and environment created")

--2025-11-26 17:27:46--  https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh
Resolving repo.anaconda.com (repo.anaconda.com)... 104.16.32.241, 104.16.191.158, 2606:4700::6810:bf9e, ...
Connecting to repo.anaconda.com (repo.anaconda.com)|104.16.32.241|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 157891003 (151M) [application/octet-stream]
Saving to: ‘Miniconda3-latest-Linux-x86_64.sh’


2025-11-26 17:27:46 (230 MB/s) - ‘Miniconda3-latest-Linux-x86_64.sh’ saved [157891003/157891003]

PREFIX=/usr/local
Unpacking bootstrapper...
Unpacking payload...

Installing base environment...

Preparing transaction: ...working... done
Executing transaction: ...working... done
installation finished.
    You currently have a PYTHONPATH environment variable set. This may cause
    unexpected behavior when running the Python interpreter in Miniconda3.
    For best results, please verify that your PYTHONPATH only points to
    directories of packages that are 
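
Because the install output is long, a quick check that the environment actually exists can save time before moving on. This is a minimal sketch, assuming the environment ends up under `/usr/local/envs/lwnet` (the prefix that appears in the traceback in section 7); the helpers `env_python` and `env_ready` are hypothetical.

```python
from pathlib import Path

# Miniconda was installed with `-p /usr/local`, so named environments are
# assumed to live under /usr/local/envs/<name>.
ENV_PREFIX = Path("/usr/local/envs/lwnet")


def env_python(prefix: Path) -> Path:
    """Path where the environment's Python interpreter is expected."""
    return prefix / "bin" / "python"


def env_ready(prefix: Path) -> bool:
    """True if the environment's interpreter file exists."""
    return env_python(prefix).is_file()


if env_ready(ENV_PREFIX):
    print(f"✓ Environment found: {env_python(ENV_PREFIX)}")
else:
    print(f"⚠️ No interpreter at {env_python(ENV_PREFIX)}; re-run the install cell")
```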

## 4. Download Public Datasets

This will download 7 public datasets (DRIVE, CHASE-DB, HRF, STARE, IOSTAR, ARIA, RC-SLO).

**Note:** This may take 10-15 minutes depending on your connection.

In [8]:
# Download datasets
!source activate lwnet && python get_public_data.py

# Verify datasets
!ls -la data/

downloading data
  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100 91.2M    0 91.2M    0     0  21.4M      0 --:--:--  0:00:04 --:--:-- 21.4M
--2025-11-26 17:32:21--  http://webeye.ophth.uiowa.edu/abramoff/AV_groundTruth.zip
Resolving webeye.ophth.uiowa.edu (webeye.ophth.uiowa.edu)... 129.255.116.103
Connecting to webeye.ophth.uiowa.edu (webeye.ophth.uiowa.edu)|129.255.116.103|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 30174977 (29M) [application/x-zip-compressed]
Saving to: ‘AV_groundTruth.zip’


2025-11-26 17:32:32 (2.61 MB/s) - ‘AV_groundTruth.zip’ saved [30174977/30174977]

Archive:  AV_groundTruth.zip
   creating: data/DRIVE/AV_groundTruth/
   creating: data/DRIVE/AV_groundTruth/test/
   creating: data/DRIVE/AV_groundTruth/test/vessel/
  inflating: data/DRIVE/AV_groundTruth/test/vessel/01_test.png  
  inflating: data/DRIVE/AV_groundTruth/tes
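
The `ls` listing only shows that folders exist. A slightly stricter check, sketched below, looks for the `train.csv` files that the training commands in this notebook rely on. Only the three dataset folder names that appear in this notebook (`DRIVE`, `CHASEDB`, `HRF`) are checked; the helper `missing_train_csvs` is hypothetical.

```python
from pathlib import Path

# Folder names taken from the commands in this notebook. The other four
# datasets are left out because their folder names are not shown here.
DATASETS = ["DRIVE", "CHASEDB", "HRF"]


def missing_train_csvs(data_root, datasets):
    """Return the dataset names whose <data_root>/<name>/train.csv is absent."""
    root = Path(data_root)
    return [name for name in datasets if not (root / name / "train.csv").is_file()]


missing = missing_train_csvs("data", DATASETS)
if missing:
    print("⚠️ Missing train.csv for:", ", ".join(missing))
else:
    print("✓ All expected train.csv files are present")
```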

## 5. Train Model on DRIVE Dataset

Training a W-Net model on the DRIVE dataset.

**Training time:** ~20-30 minutes on Colab GPU

In [None]:
# !pip install -q torch torchvision numpy pandas matplotlib Pillow scikit-image scikit-learn
# !pip show torch torchvision

In [None]:
!source activate lwnet && python train_cyclical.py \
    --csv_train data/DRIVE/train.csv \
    --cycle_lens 20/50 \
    --model_name wnet \
    --save_path wnet_drive \
    --device cuda:0

  import pkg_resources
* Training on device 
* Creating Dataloaders, batch size = 4, workers = 0
* Instantiating a wnet model
Total params: 68,482
* Instantiating loss function BCEWithLogitsLoss()
* Starting to train
 ----------
Cycle 1/20
  0% 0/50 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "/content/lwnet/train_cyclical.py", line 310, in <module>
    m1, m2, m3=train_model(model, optimizer, criterion, train_loader, val_loader, scheduler, grad_acc_steps, metric, experiment_path)
               ~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/content/lwnet/train_cyclical.py", line 160, in train_model
    tr_logits, tr_labels, tr_loss = train_one_cycle(train_loader, model, criterion, optimizer, scheduler, grad_acc_steps, cycle)
                                    ~~~~~~~~~~~~~

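The `--cycle_lens 20/50` argument determines how long training runs. Assuming the `N/M` format means N cycles of M epochs each (which is consistent with the `Cycle 1/20` header and the `0/50` progress bar in the log above), the total epoch count can be worked out as in this sketch. The parser `parse_cycle_lens` is hypothetical and is not the repository's own implementation.

```python
def parse_cycle_lens(spec: str):
    """Split an 'N/M' string into (number of cycles, epochs per cycle)."""
    n_cycles, epochs_per_cycle = (int(part) for part in spec.split("/"))
    return n_cycles, epochs_per_cycle


n_cycles, epochs_per_cycle = parse_cycle_lens("20/50")
total_epochs = n_cycles * epochs_per_cycle
print(f"{n_cycles} cycles x {epochs_per_cycle} epochs = {total_epochs} epochs")
```

Lowering the first number is the simplest way to shorten a trial run while checking that the pipeline works end to end.
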
## 6. Generate Predictions

Generate segmentation predictions on the DRIVE test set.

In [None]:
!source activate lwnet && python generate_results.py \
    --config_file experiments/wnet_drive/config.cfg \
    --dataset DRIVE \
    --device cuda:0

print("\n✓ Predictions generated in results/DRIVE/experiments/wnet_drive")

## 7. Evaluate Performance

Compute performance metrics (AUC, Dice, etc.) on the test set.
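
For orientation, the Dice score measures the overlap between a binary prediction and the ground truth as 2|P ∩ T| / (|P| + |T|). The sketch below illustrates that definition on flat 0/1 lists; it is not the code used by `analyze_results.py`, and returning 1.0 when both masks are empty is a convention chosen here.

```python
def dice_score(pred, target):
    """Dice = 2*|P ∩ T| / (|P| + |T|) for flat sequences of 0/1 labels."""
    if len(pred) != len(target):
        raise ValueError("pred and target must have the same length")
    intersection = sum(p * t for p, t in zip(pred, target))
    total = sum(pred) + sum(target)
    # Convention chosen for this sketch: two empty masks count as a perfect match.
    return 1.0 if total == 0 else 2.0 * intersection / total


pred = [1, 1, 0, 0, 1]
target = [1, 0, 0, 0, 1]
print(round(dice_score(pred, target), 3))  # 2*2 / (3+2) = 0.8
```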

In [None]:
!source activate lwnet && python analyze_results.py \
    --path_train_preds results/DRIVE/experiments/wnet_drive \
    --path_test_preds results/DRIVE/experiments/wnet_drive \
    --train_dataset DRIVE \
    --test_dataset DRIVE

Traceback (most recent call last):
  File "analyze_results.py", line 8, in <module>
    import matplotlib.pyplot as plt
  File "/usr/local/envs/lwnet/lib/python3.7/site-packages/matplotlib/pyplot.py", line 2320, in <module>
    switch_backend(rcParams["backend"])
  File "/usr/local/envs/lwnet/lib/python3.7/site-packages/matplotlib/pyplot.py", line 260, in switch_backend
    class backend_mod(matplotlib.backend_bases._Backend):
  File "/usr/local/envs/lwnet/lib/python3.7/site-packages/matplotlib/pyplot.py", line 261, in backend_mod
    locals().update(vars(importlib.import_module(backend_name)))
  File "/usr/local/envs/lwnet/lib/python3.7/importlib/__init__.py", line 127, in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
ModuleNotFoundError: No module named 'matplotlib_inline'
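
The traceback above shows matplotlib trying to import a `matplotlib_inline` backend that is not installed in the `lwnet` environment. One plausible cause, stated here as an assumption, is that the Colab kernel exports `MPLBACKEND` pointing at its inline backend and the `!` subprocess inherits it. A minimal sketch of a workaround is to override that variable with the non-interactive `Agg` backend for the child process. The helper `headless_env` is hypothetical.

```python
import os


def headless_env(base=None):
    """Copy an environment mapping and force matplotlib's non-interactive Agg backend."""
    env = dict(os.environ if base is None else base)
    env["MPLBACKEND"] = "Agg"
    return env


env = headless_env()
print("MPLBACKEND for child processes:", env["MPLBACKEND"])

# In Colab this mapping could be passed to the evaluation step, for example:
#   subprocess.run("source activate lwnet && python analyze_results.py ...",
#                  shell=True, executable="/bin/bash", env=env)
# The arguments are elided here; use the ones from the cell above.
```

An equivalent shell-only variant would be to prefix the Python call with `MPLBACKEND=Agg` inside the existing `!source activate lwnet && ...` command.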


## 8. Visualize Results (Optional)

Display some predictions alongside ground truth.

In [None]:
import glob
import os  # needed below for os.path; imported here so the cell runs on its own

import matplotlib.pyplot as plt
import numpy as np
from PIL import Image

# Get a sample image
pred_files = sorted(glob.glob('results/DRIVE/experiments/wnet_drive/*.png'))

if pred_files:
    sample_file = pred_files[0]
    sample_name = os.path.basename(sample_file).replace('.png', '')
    
    # Load images
    img_path = f'data/DRIVE/images/{sample_name}.png'
    gt_path = f'data/DRIVE/manual/{sample_name}.png'
    pred_path = sample_file
    
    if os.path.exists(img_path) and os.path.exists(gt_path):
        img = Image.open(img_path)
        gt = Image.open(gt_path)
        pred = Image.open(pred_path)
        
        # Display
        fig, axes = plt.subplots(1, 3, figsize=(15, 5))
        axes[0].imshow(img)
        axes[0].set_title('Original Image')
        axes[0].axis('off')
        
        axes[1].imshow(gt, cmap='gray')
        axes[1].set_title('Ground Truth')
        axes[1].axis('off')
        
        axes[2].imshow(pred, cmap='gray')
        axes[2].set_title('Prediction')
        axes[2].axis('off')
        
        plt.tight_layout()
        plt.show()
    else:
        print(f"⚠️ Missing image or ground truth for {sample_name}")
else:
    print("No prediction files found")

## 9. Download Results

Compress and download the trained model and results.

In [None]:
# Compress results
!zip -r lwnet_results.zip experiments/wnet_drive results/DRIVE/experiments/wnet_drive

# Download (this will trigger a download in your browser)
from google.colab import files
files.download('lwnet_results.zip')

print("✓ Results compressed and ready for download")

## Additional Experiments

### Train on CHASE-DB

```python
!python train_cyclical.py \
    --csv_train data/CHASEDB/train.csv \
    --cycle_lens 40/50 \
    --model_name wnet \
    --save_path wnet_chasedb \
    --device cuda:0
```

### Cross-Dataset Evaluation

```python
# Generate predictions on CHASE-DB using DRIVE model
!python generate_results.py \
    --config_file experiments/wnet_drive/config.cfg \
    --dataset CHASEDB \
    --device cuda:0

# Evaluate cross-dataset performance
!python analyze_results.py \
    --path_train_preds results/DRIVE/experiments/wnet_drive \
    --path_test_preds results/CHASEDB/experiments/wnet_drive \
    --train_dataset DRIVE \
    --test_dataset CHASEDB
```
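
The commands above follow a path pattern of `results/<test dataset>/experiments/<experiment name>`. As a sketch inferred from those commands only, the arguments for any train/test pair could be assembled as follows; the helpers `preds_path` and `analyze_args` are hypothetical.

```python
def preds_path(dataset: str, experiment: str) -> str:
    """Prediction folder for one experiment evaluated on one dataset."""
    return f"results/{dataset}/experiments/{experiment}"


def analyze_args(train_dataset: str, test_dataset: str, experiment: str):
    """Argument list for analyze_results.py, mirroring the commands above."""
    return [
        "--path_train_preds", preds_path(train_dataset, experiment),
        "--path_test_preds", preds_path(test_dataset, experiment),
        "--train_dataset", train_dataset,
        "--test_dataset", test_dataset,
    ]


args = analyze_args("DRIVE", "CHASEDB", "wnet_drive")
print("python analyze_results.py " + " ".join(args))
```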

### Train on HRF (Higher Resolution)

```python
!python train_cyclical.py \
    --csv_train data/HRF/train.csv \
    --cycle_lens 30/50 \
    --model_name wnet \
    --save_path wnet_hrf_1024 \
    --im_size 1024 \
    --batch_size 2 \
    --grad_acc_steps 1 \
    --device cuda:0
```

## Notes

- **Runtime Limits:** Colab has session time limits. For long training, consider Colab Pro or save checkpoints regularly.
- **Storage:** Colab provides limited storage. Clean up datasets/results if needed.
- **GPU Memory:** If you encounter OOM errors, reduce batch size or image size.
- **Persistence:** Files in Colab are temporary. Download important results before session ends.
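
One way to act on the persistence note is to copy the experiment folder to mounted storage after training. The sketch below assumes Google Drive has been mounted at `/content/drive` and uses a destination folder name, `lwnet_backup`, chosen purely for illustration; the helper `backup_dir` is hypothetical.

```python
import shutil
from pathlib import Path


def backup_dir(src, dst_root):
    """Copy src into dst_root/<src name>, overwriting files from earlier backups."""
    src = Path(src)
    dst = Path(dst_root) / src.name
    shutil.copytree(src, dst, dirs_exist_ok=True)  # requires Python 3.8+
    return dst


# In Colab, Google Drive is typically mounted first with:
#   from google.colab import drive; drive.mount('/content/drive')
# DRIVE_ROOT is an assumed destination; adjust it to your own Drive layout.
DRIVE_ROOT = Path("/content/drive/MyDrive/lwnet_backup")

if Path("experiments/wnet_drive").is_dir() and DRIVE_ROOT.parent.is_dir():
    print("Backed up to", backup_dir("experiments/wnet_drive", DRIVE_ROOT))
else:
    print("Nothing to back up yet, or Drive is not mounted")
```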

## References

```
The Little W-Net That Could: State-of-the-Art Retinal Vessel Segmentation with Minimalistic Models
Adrian Galdran, André Anjos, Jose Dolz, Hadi Chakor, Hervé Lombaert, Ismail Ben Ayed
https://arxiv.org/abs/2009.01907, Sep. 2020
```
