Skip to content

Commit

Permalink
Merge pull request #116 from initze/fix/113-wandb
Browse files Browse the repository at this point in the history
Make (simple) wandb logging work again
  • Loading branch information
initze committed Apr 30, 2024
2 parents a8d2593 + e49253e commit 689e55c
Show file tree
Hide file tree
Showing 4 changed files with 101 additions and 81 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,4 @@ dist/
.python-version
requirements.lock
requirements-dev.lock
wandb
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ dependencies = [
"opencv-python>=4.9.0.80",
"swifter>=1.4.0",
"mkdocs-awesome-pages-plugin>=2.9.2",
"rich"
]
readme = "README.md"
requires-python = ">= 3.10"
Expand Down
121 changes: 68 additions & 53 deletions src/thaw_slump_segmentation/scripts/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
"""
Usecase 2 Training Script
"""

import argparse
import re
import subprocess
Expand All @@ -18,33 +19,40 @@
import torch
import torch.nn as nn
import yaml
from rich import pretty, traceback
from tqdm import tqdm

from ..data_loading import get_loader, get_vis_loader, get_slump_loader, DataSources
from ..metrics import Metrics, Accuracy, Precision, Recall, F1, IoU
from ..models import create_model, create_loss
from ..utils import showexample, plot_metrics, plot_precision_recall, init_logging, get_logger, yaml_custom

parser = argparse.ArgumentParser(description='Training script',
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('-s', '--summary', action='store_true',
help='Only print model summary and return.')
parser.add_argument("--data_dir", default='data', type=Path, help="Path to data processing dir")
parser.add_argument("--log_dir", default='logs', type=Path, help="Path to log dir")
parser.add_argument('-n', '--name', default='',
help='Give this run a name, so that it will be logged into logs/<NAME>_<timestamp>.')
parser.add_argument('-c', '--config', default='config.yml', type=Path,
help='Specify run config to use.')
parser.add_argument('-r', '--resume', default='',
help='Resume from the specified checkpoint.'
'Can be either a run-id (e.g. "2020-06-29_18-12-03") to select the last'
'checkpoint of that run, or a direct path to a checkpoint to be loaded.'
'Overrides the resume option in the config file if given.'
)
parser.add_argument('-wp', '--wandb_project', default='thaw-slump-segmentation',
help='Set a project name for weights and biases')
parser.add_argument('-wn', '--wandb_name', default=None,
help='Set a run name for weights and biases')
import wandb

from ..data_loading import DataSources, get_loader, get_slump_loader, get_vis_loader
from ..metrics import F1, Accuracy, IoU, Metrics, Precision, Recall
from ..models import create_loss, create_model
from ..utils import get_logger, init_logging, plot_metrics, plot_precision_recall, showexample, yaml_custom

traceback.install(show_locals=True)
pretty.install()

parser = argparse.ArgumentParser(description='Training script', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('-s', '--summary', action='store_true', help='Only print model summary and return.')
parser.add_argument('--data_dir', default='data', type=Path, help='Path to data processing dir')
parser.add_argument('--log_dir', default='logs', type=Path, help='Path to log dir')
parser.add_argument(
'-n', '--name', default='', help='Give this run a name, so that it will be logged into logs/<NAME>_<timestamp>.'
)
parser.add_argument('-c', '--config', default='config.yml', type=Path, help='Specify run config to use.')
parser.add_argument(
'-r',
'--resume',
default='',
help='Resume from the specified checkpoint.'
'Can be either a run-id (e.g. "2020-06-29_18-12-03") to select the last'
'checkpoint of that run, or a direct path to a checkpoint to be loaded.'
'Overrides the resume option in the config file if given.',
)
parser.add_argument(
'-wp', '--wandb_project', default='thaw-slump-segmentation', help='Set a project name for weights and biases'
)
parser.add_argument('-wn', '--wandb_name', default=None, help='Set a run name for weights and biases')


class Engine:
Expand Down Expand Up @@ -73,7 +81,7 @@ def __init__(self):
encoder_name=m['encoder'],
encoder_weights=None if m['encoder_weights'] == 'random' else m['encoder_weights'],
classes=1,
in_channels=m['input_channels']
in_channels=m['input_channels'],
)

# make parallel
Expand All @@ -94,7 +102,7 @@ def __init__(self):
self.logger.info(f"Resuming training from checkpoint {self.config['resume']}")
self.model.load_state_dict(torch.load(self.config['resume']))

self.dev = torch.device("cpu") if not torch.cuda.is_available() else torch.device("cuda")
self.dev = torch.device('cpu') if not torch.cuda.is_available() else torch.device('cuda')
self.logger.info(f'Training on {self.dev} device')

self.model = self.model.to(self.dev)
Expand All @@ -109,21 +117,23 @@ def __init__(self):

if args.summary:
from torchsummary import summary

summary(self.model, [(self.config['model']['input_channels'], 256, 256)])
sys.exit(0)

self.dataset_cache = {}

self.vis_predictions = None
self.vis_loader, self.vis_names = get_vis_loader(self.config['visualization_tiles'],
batch_size=self.config['batch_size'],
data_sources=self.data_sources,
data_root=self.DATA_ROOT)
self.vis_loader, self.vis_names = get_vis_loader(
self.config['visualization_tiles'],
batch_size=self.config['batch_size'],
data_sources=self.data_sources,
data_root=self.DATA_ROOT,
)

# Write the config YML to the run-folder
self.config['run_info'] = dict(
timestamp=timestamp,
git_head=subprocess.check_output(["git", "describe"], encoding='utf8').strip()
timestamp=timestamp, git_head=subprocess.check_output(['git', 'describe'], encoding='utf8').strip()
)
with open(self.log_dir / 'config.yml', 'w') as f:
yaml.dump(self.config, f)
Expand All @@ -134,7 +144,11 @@ def __init__(self):
# Metrics and Weights and Biases initialization
self.trn_metrics = {}
self.val_metrics = {}
wandb.init(project=args.wandb_project, name=args.wandb_name, config=self.config, entity='ml4earth')
print('wandb project:', args.wandb_project)
print('wandb name:', args.wandb_name)
print('config:', self.config)
print('entity:', 'ml4earth')
wandb.init(project=args.wandb_project, name=args.wandb_name, config=self.config, entity='ingmarnitze_team')

def run(self):
for phase in self.config['schedule']:
Expand All @@ -144,9 +158,9 @@ def run(self):
self.loss_function = create_loss(scoped_get('loss_function', phase, self.config)).to(self.dev)

for step in phase['steps']:
if type(step) is dict:
if isinstance(step, dict):
assert len(step) == 1
(command, key), = step.items()
((command, key),) = step.items()
else:
command = step

Expand All @@ -161,9 +175,9 @@ def run(self):
elif command == 'log_images':
self.log_images()
if self.scheduler:
print("before step:", self.scheduler.get_last_lr())
print('before step:', self.scheduler.get_last_lr())
self.scheduler.step()
print("after step:", self.scheduler.get_last_lr())
print('after step:', self.scheduler.get_last_lr())

def get_dataloader(self, name):
if name in self.dataset_cache:
Expand Down Expand Up @@ -207,15 +221,15 @@ def train_epoch(self, train_loader):

metrics_terms = {}
if isinstance(y_hat, (tuple, list)):
# Deep Supervision
deep_super_losses = [self.loss_function(pred.squeeze(1), target) for pred in y_hat]
y_hat = y_hat[0].squeeze(1)
loss = sum(deep_super_losses)
metrics_terms['Loss'] = deep_super_losses[0].detach()
metrics_terms['Deep Supervision Loss'] = loss.detach()
# Deep Supervision
deep_super_losses = [self.loss_function(pred.squeeze(1), target) for pred in y_hat]
y_hat = y_hat[0].squeeze(1)
loss = sum(deep_super_losses)
metrics_terms['Loss'] = deep_super_losses[0].detach()
metrics_terms['Deep Supervision Loss'] = loss.detach()
else:
loss = self.loss_function(y_hat, target)
metrics_terms['Loss'] = loss.detach()
loss = self.loss_function(y_hat, target)
metrics_terms['Loss'] = loss.detach()

loss.backward()
self.opt.step()
Expand Down Expand Up @@ -280,7 +294,7 @@ def val_epoch(self, val_loader):
safe_append(self.val_metrics, key, val)
safe_append(self.val_metrics, 'step', self.board_idx)
safe_append(self.val_metrics, 'epoch', self.epoch)
wandb.log({'val/{k}': v for k, v in metrics_vals.items()}, step=self.board_idx)
wandb.log({'val/{k}': v for k, v in self.val_metrics.items()}, step=self.board_idx)

def log_images(self):
self.logger.debug(f'Epoch {self.epoch} - Image Logging')
Expand All @@ -296,8 +310,9 @@ def log_images(self):
(self.log_dir / 'tile_predictions').mkdir(exist_ok=True)
for i, tile in enumerate(self.vis_names):
filename = self.log_dir / 'tile_predictions' / f'{tile}.jpg'
showexample(self.vis_loader.dataset[i], self.vis_predictions[i],
filename, self.data_sources, step=self.board_idx)
showexample(
self.vis_loader.dataset[i], self.vis_predictions[i], filename, self.data_sources, step=self.board_idx
)

outdir = self.log_dir / 'metrics_plots'
outdir.mkdir(exist_ok=True)
Expand All @@ -307,7 +322,7 @@ def log_images(self):
def setup_lr_scheduler(self):
# Scheduler
if 'learning_rate_scheduler' not in self.config.keys():
print("running without learning rate scheduler")
print('running without learning rate scheduler')
self.scheduler = None
elif self.config['learning_rate_scheduler'] == 'StepLR':
if 'lr_step_size' not in self.config.keys():
Expand Down Expand Up @@ -345,8 +360,8 @@ def safe_append(dictionary, key, value):


def main():
args = parser.parse_args()
Engine().run()

if __name__ == "__main__":
main()

if __name__ == '__main__':
main()
59 changes: 31 additions & 28 deletions src/thaw_slump_segmentation/utils/plot_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,34 +3,39 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import os

import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
import torch
from matplotlib.ticker import MaxNLocator
from matplotlib.colors import LinearSegmentedColormap
from matplotlib.ticker import MaxNLocator

FLATUI = {'Turquoise': (0.10196078431372549, 0.7372549019607844, 0.611764705882353),
'Emerald': (0.1803921568627451, 0.8, 0.44313725490196076),
'Peter River': (0.20392156862745098, 0.596078431372549, 0.8588235294117647),
'Amethyst': (0.6078431372549019, 0.34901960784313724, 0.7137254901960784),
'Wet Asphalt': (0.20392156862745098, 0.28627450980392155, 0.3686274509803922),
'Green Sea': (0.08627450980392157, 0.6274509803921569, 0.5215686274509804),
'Nephritis': (0.15294117647058825, 0.6823529411764706, 0.3764705882352941),
'Belize Hole': (0.1607843137254902, 0.5019607843137255, 0.7254901960784313),
'Wisteria': (0.5568627450980392, 0.26666666666666666, 0.6784313725490196),
'Midnight Blue': (0.17254901960784313, 0.24313725490196078, 0.3137254901960784),
'Sun Flower': (0.9450980392156862, 0.7686274509803922, 0.058823529411764705),
'Carrot': (0.9019607843137255, 0.49411764705882355, 0.13333333333333333),
'Alizarin': (0.9058823529411765, 0.2980392156862745, 0.23529411764705882),
'Clouds': (0.9254901960784314, 0.9411764705882353, 0.9450980392156862),
'Concrete': (0.5843137254901961, 0.6470588235294118, 0.6509803921568628),
'Orange': (0.9529411764705882, 0.611764705882353, 0.07058823529411765),
'Pumpkin': (0.8274509803921568, 0.32941176470588235, 0.0),
'Pomegranate': (0.7529411764705882, 0.2235294117647059, 0.16862745098039217),
'Silver': (0.7411764705882353, 0.7647058823529411, 0.7803921568627451),
'Asbestos': (0.4980392156862745, 0.5490196078431373, 0.5529411764705883)}
import wandb

FLATUI = {
'Turquoise': (0.10196078431372549, 0.7372549019607844, 0.611764705882353),
'Emerald': (0.1803921568627451, 0.8, 0.44313725490196076),
'Peter River': (0.20392156862745098, 0.596078431372549, 0.8588235294117647),
'Amethyst': (0.6078431372549019, 0.34901960784313724, 0.7137254901960784),
'Wet Asphalt': (0.20392156862745098, 0.28627450980392155, 0.3686274509803922),
'Green Sea': (0.08627450980392157, 0.6274509803921569, 0.5215686274509804),
'Nephritis': (0.15294117647058825, 0.6823529411764706, 0.3764705882352941),
'Belize Hole': (0.1607843137254902, 0.5019607843137255, 0.7254901960784313),
'Wisteria': (0.5568627450980392, 0.26666666666666666, 0.6784313725490196),
'Midnight Blue': (0.17254901960784313, 0.24313725490196078, 0.3137254901960784),
'Sun Flower': (0.9450980392156862, 0.7686274509803922, 0.058823529411764705),
'Carrot': (0.9019607843137255, 0.49411764705882355, 0.13333333333333333),
'Alizarin': (0.9058823529411765, 0.2980392156862745, 0.23529411764705882),
'Clouds': (0.9254901960784314, 0.9411764705882353, 0.9450980392156862),
'Concrete': (0.5843137254901961, 0.6470588235294118, 0.6509803921568628),
'Orange': (0.9529411764705882, 0.611764705882353, 0.07058823529411765),
'Pumpkin': (0.8274509803921568, 0.32941176470588235, 0.0),
'Pomegranate': (0.7529411764705882, 0.2235294117647059, 0.16862745098039217),
'Silver': (0.7411764705882353, 0.7647058823529411, 0.7803921568627451),
'Asbestos': (0.4980392156862745, 0.5490196078431373, 0.5529411764705883),
}


def flatui_cmap(*colors):
Expand Down Expand Up @@ -62,8 +67,7 @@ def showexample(data, preds, filename, data_sources, step):
# First plot
ROWS = 6
m = 0.02
gridspec_kw = dict(left=m, right=1 - m, top=1 - m, bottom=m,
hspace=0.12, wspace=m)
gridspec_kw = dict(left=m, right=1 - m, top=1 - m, bottom=m, hspace=0.12, wspace=m)
N = 1 + int(np.ceil(len(preds) / ROWS))
fig, ax = plt.subplots(ROWS, N, figsize=(3 * N, 3 * ROWS), gridspec_kw=gridspec_kw)
ax = ax.T.reshape(-1)
Expand Down Expand Up @@ -133,7 +137,7 @@ def showexample(data, preds, filename, data_sources, step):
ax[2].imshow(pred, **heatmap_args)
ax[2].set_title('Prediction')
for axis in ax:
axis.axis('off')
axis.axis('off')
wandb.log({filename.stem: wandb.Image(fig)}, step=step)


Expand All @@ -151,8 +155,7 @@ def read_metrics_file(file_path):

data.append([int(epoch.replace('Epoch', '')), str(val_type), *acc_vals])

df = pd.DataFrame(columns=['epoch', 'val_type', 'accuracy', 'precision', 'recall', 'f1', 'iou', 'loss'],
data=data)
df = pd.DataFrame(columns=['epoch', 'val_type', 'accuracy', 'precision', 'recall', 'f1', 'iou', 'loss'], data=data)
return df


Expand Down Expand Up @@ -195,6 +198,6 @@ def plot_precision_recall(train_metrics, val_metrics, outdir='.'):

fig.tight_layout()

outfile = os.path.join(outdir, f'precision_recall.png')
outfile = os.path.join(outdir, 'precision_recall.png')
fig.savefig(outfile)
fig.clear()

0 comments on commit 689e55c

Please sign in to comment.