In [1]:
import dataclasses
import glob
import json
import logging
import os
import pickle
from functools import partial
from logging.config import valid_ident
from os.path import join

from tqdm import tqdm
import flax.linen as nn
from flax.core.frozen_dict import freeze,unfreeze
from flax.training.train_state import TrainState
from jax.config import config
import jax
import jax.numpy as jnp
import numpy as np
import optax

import lacss
from lacss.utils import show_images

import tensorflow as tf
tf.config.set_visible_devices([], "GPU")

#########################################

# ---- Run configuration -------------------------------------------------------
# Dataset root; "train.json", "train/" and "test/" are resolved relative to it.
datapath: str = "../../../a431"
# Pickle holding a (module, params) pair used to warm-start the model.
transfer: str = '../../runs/supervised/20230307-164258/convnext_p4.pkl'
# Root directory for checkpoints; a timestamped sub-directory is appended later.
logpath: str = "../../runs/a431_train/"
# Seed handed to lacss.train.Trainer.
seed: int = 42
# Batch size used for every length bucket of the training dataset.
batchsize: int = 1
# Value of the epoch counter at which the second training loop stops.
n_epochs: int = 15
# Value of the epoch counter at which the first (warm-up) training loop stops.
warmup_epochs: int = 3
# Number of training steps counted as one "epoch" (checkpoint + validation cadence).
steps_per_epoch: int = 3000
# Starting value of the epoch counter.
init_epoch: int = 0
# Argument passed to lacss.losses.AuxSizeLoss (presumably a loss weight -- TODO confirm).
size_loss: float = 0.01
# Number of buckets used when grouping training samples by location count.
n_buckets: int = 4

In [2]:
import time

# Every run writes into its own sub-directory named after the launch time
# (minute resolution, local time).
# NOTE: this rebinds `logpath`, so executing the cell a second time nests
# another timestamp level below the previous one.
log_sub_dir = time.strftime("%Y%m%d-%H%M", time.localtime())
logpath = os.path.join(logpath, log_sub_dir)
logpath

'../../runs/a431_train/20230324-1150'

In [3]:
def train_parser_semisupervised(inputs):
    """Convert one raw training record into the inputs used for training.

    The record is first passed through ``lacss.data.parse_train_data_func``
    with ``size_jitter=(0.85, 1.15)``. The resulting image then receives a
    random contrast change (factor range 0.6-1.4) followed by a random
    brightness change (max delta 0.3).

    Args:
        inputs: a record as produced by the training dataset.

    Returns:
        A dict with two entries: ``image`` (the augmented image) and
        ``gt_locations`` (the parsed record's ``locations`` entry). No other
        entries of the parsed record are forwarded.
    """
    parsed = lacss.data.parse_train_data_func(inputs, size_jitter=(0.85, 1.15))

    # Photometric augmentation: contrast is applied before brightness.
    augmented = tf.image.random_brightness(
        tf.image.random_contrast(parsed["image"], 0.6, 1.4),
        0.3,
    )

    return {"image": augmented, "gt_locations": parsed["locations"]}


def val_parser(inputs):
    """Split a validation record into ``(model_inputs, ground_truth)``.

    Args:
        inputs: mapping providing ``image``, ``bboxes``, ``locations`` and
            ``label`` entries.

    Returns:
        A 2-tuple of dicts. The first holds only ``image``; the second holds
        ``gt_boxes``, ``gt_locations`` and ``gt_labels``, taken from the
        record's ``bboxes``, ``locations`` and ``label`` entries respectively.
    """
    model_inputs = {"image": inputs["image"]}
    ground_truth = {
        "gt_boxes": inputs["bboxes"],
        "gt_locations": inputs["locations"],
        "gt_labels": inputs["label"],
    }
    return model_inputs, ground_truth

# ---- Training / validation datasets -----------------------------------------

# Range of per-image location counts that is divided evenly into `n_buckets`.
# NOTE(review): value kept from the original code -- confirm 2056 (rather than
# e.g. 2048) is intended.
_LOCATION_COUNT_RANGE = 2056


# Named functions are used instead of inline lambdas: with lambdas, TensorFlow
# printed a deprecation notice about how lambda source is located
# (https://github.com/tensorflow/tensorflow/issues/56089).
def _has_locations(x):
    """Return a boolean tensor: does the sample have any ground-truth location?"""
    return tf.size(x["gt_locations"]) > 0


def _n_locations(x):
    """Return the size of the first axis of ``gt_locations`` (bucketing length)."""
    return tf.shape(x["gt_locations"])[0]


ds_train = lacss.data.dataset_from_simple_annotations(
    join(datapath, "train.json"),
    join(datapath, "train"),
    [512, 512, 1],
)

# Infinite stream of augmented samples; samples without any location are dropped.
ds_train = ds_train.repeat().map(train_parser_semisupervised)
ds_train = ds_train.filter(_has_locations)

# Group samples by location count so that `gt_locations` is padded (with -1)
# to a small set of fixed lengths instead of a new length for every batch.
# NOTE(review): per the tf.data docs, `pad_to_bucket_boundary=True` requires
# that no element is longer than max(bucket_boundaries) -- confirm the data
# never exceeds the last boundary computed here.
_bucket_width = _LOCATION_COUNT_RANGE // n_buckets
ds_train = ds_train.bucket_by_sequence_length(
    element_length_func=_n_locations,
    bucket_boundaries=[k * _bucket_width + 1 for k in range(1, n_buckets)],
    bucket_batch_sizes=(batchsize,) * n_buckets,
    padding_values=-1.0,
    pad_to_bucket_boundary=True,
)

# Validation images and masks are paired by their position in the sorted
# listings, so both listings must be non-empty and of equal length.
imglist = sorted(glob.glob(join(datapath, "test", "img*")))
masklist = sorted(glob.glob(join(datapath, "test", "mask*")))
if not imglist:
    raise FileNotFoundError(
        f"no validation images matching 'img*' under {join(datapath, 'test')}"
    )
if len(imglist) != len(masklist):
    raise ValueError(
        f"validation set mismatch: {len(imglist)} images vs {len(masklist)} masks"
    )
ds_val = lacss.data.dataset_from_img_mask_pairs(imglist, masklist, [512, 512, 1])
ds_val = ds_val.map(val_parser).batch(1)

# Wrap the tf.data pipelines for the lacss trainer.
# steps=-1 presumably means "no step limit" for the repeated training stream
# -- TODO confirm against lacss.train.TFDatasetAdapter.
train_data = lacss.train.TFDatasetAdapter(ds_train, steps=-1).get_dataset()
val_data = lacss.train.TFDatasetAdapter(ds_val).get_dataset()

Instructions for updating:
Lambda fuctions will be no more assumed to be used in the statement where they are used, or at least in the same block. https://github.com/tensorflow/tensorflow/issues/56089


In [4]:
# ---- Model, losses and trainer ----------------------------------------------

# SECURITY NOTE: pickle.load can execute arbitrary code stored in the file;
# only point `transfer` at checkpoints you produced yourself or fully trust.
with open(transfer, "rb") as ckpt_file:
    module, params = pickle.load(ckpt_file)

# Re-use the pretrained module's configuration for the wrapped model and
# enable both auxiliary heads with empty (default) configs.
cfg = {
    "cfg": dataclasses.asdict(module),
    "aux_edge_cfg": {},
    "aux_fg_cfg": {},
}
model = lacss.modules.lacss.LacssWithHelper(**cfg)

# Loss set for the first training phase.
losses = [
    lacss.losses.LPNLoss(),
    lacss.losses.AuxEdgeLoss(),
    lacss.losses.InstanceOverlapLoss(),
    lacss.losses.AuxSizeLoss(size_loss),
    lacss.losses.AuxSegLoss(),
]

trainer = lacss.train.Trainer(
    model=model,
    optimizer=optax.adamw(learning_rate=0.001),
    losses=losses,
    seed=seed,
    strategy=lacss.train.strategy.VMapped,
)

# Initialise the trainer from the validation data, then overwrite the `_lacss`
# entry of the fresh parameters with the weights loaded from the transfer
# checkpoint. All other top-level entries keep their initial values.
trainer.initialize(val_data)

new_params = unfreeze(trainer.params)
new_params["_lacss"] = params
trainer.state = trainer.state.replace(params=freeze(new_params))

In [5]:
# ---- Phase 1: warm-up training ----------------------------------------------
# Train with the loss set given to the Trainer until `epoch` reaches
# `warmup_epochs`. After every `steps_per_epoch` steps: print the running
# losses, write a checkpoint, reset the trainer and report box AP on the
# validation data.
epoch = init_epoch

# Guard before touching the iterator: pulling even a single item from
# trainer.train() would run a training step (assumes one optimisation step per
# yielded item -- TODO confirm), which must not happen when resuming at or
# beyond the end of this phase.
if epoch < warmup_epochs:
    pb = tqdm(
        trainer.train(train_data, rng_cols=["droppath"], training=True),
        total=(warmup_epochs - epoch) * steps_per_epoch,
    )
    for steps, logs in enumerate(pb):
        if (steps + 1) % steps_per_epoch != 0:
            continue

        epoch += 1
        print(f"epoch - {epoch}")
        print(", ".join([f"{k}:{v:.4f}" for k, v in logs.items()]))

        trainer.checkpoint(join(logpath, f"cp-{epoch}"))
        trainer.reset()

        # Fresh metric objects for every evaluation round.
        val_metrics = [
            lacss.metrics.BoxAP([0.5, 0.75]),
        ]
        var_logs = trainer.test_and_compute(val_data, val_metrics)
        for k, v in var_logs.items():
            print(f"{k}: {v}")

        # Stop right after the final checkpoint/validation. Checking only at
        # the top of the next iteration would first fetch another item from
        # the training iterator, i.e. run one extra step that is never
        # checkpointed.
        if epoch >= warmup_epochs:
            break
    pb.close()

2999it [12:45,  5.08it/s]

epoch - 1
lpnloss:0.1484, aux_edge_loss:0.0125, instance_overlap_loss:0.0593, aux_size_loss:0.5071, aux_seg_loss:0.3634


3001it [18:11, 68.49s/it]

box_ap: [0.79969079 0.25104126]


5999it [26:39,  4.38it/s]

epoch - 2
lpnloss:0.1351, aux_edge_loss:0.0119, instance_overlap_loss:0.0528, aux_size_loss:0.4989, aux_seg_loss:0.3890


6000it [29:05, 43.98s/it]

box_ap: [0.75184496 0.22982508]


8999it [37:40,  4.21it/s]

epoch - 3
lpnloss:0.1316, aux_edge_loss:0.0118, instance_overlap_loss:0.0511, aux_size_loss:0.4963, aux_seg_loss:0.3996


9000it [40:06,  3.74it/s]

box_ap: [0.72635047 0.2301628 ]





In [6]:
# ---- Phase 2: self-supervised instance training ------------------------------
# Swap the loss set (InstanceOverlapLoss is dropped, SelfSupervisedInstanceLoss
# is added) and continue training until `epoch` reaches `n_epochs`. `epoch`
# carries over from the warm-up phase.
trainer.losses = [
    lacss.losses.LPNLoss(),
    lacss.losses.AuxEdgeLoss(),
    lacss.losses.AuxSegLoss(),
    lacss.losses.SelfSupervisedInstanceLoss(ver=2),
    lacss.losses.AuxSizeLoss(size_loss),
]
trainer.reset()

# Guard before touching the iterator: pulling even a single item from
# trainer.train() would run a training step (assumes one optimisation step per
# yielded item -- TODO confirm), which must not happen when the target epoch
# has already been reached.
if epoch < n_epochs:
    pb = tqdm(
        trainer.train(train_data, rng_cols=["droppath"], training=True),
        total=(n_epochs - epoch) * steps_per_epoch,
    )
    for steps, logs in enumerate(pb):
        if (steps + 1) % steps_per_epoch != 0:
            continue

        epoch += 1
        print(f"epoch - {epoch}")
        print(", ".join([f"{k}:{v:.4f}" for k, v in logs.items()]))

        trainer.checkpoint(join(logpath, f"cp-{epoch}"))
        trainer.reset()

        # Fresh metric objects for every evaluation round.
        val_metrics = [
            lacss.metrics.MaskAP([0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95]),
            # lacss.metrics.LoiAP([0.2, 0.5, 1.0]),
            lacss.metrics.BoxAP([0.5, 0.75]),
        ]
        var_logs = trainer.test_and_compute(val_data, val_metrics)
        for k, v in var_logs.items():
            print(f"{k}: {v}")

        # Stop right after the final checkpoint/validation. Checking only at
        # the top of the next iteration would first fetch another item from
        # the training iterator, i.e. run one extra step that is never
        # checkpointed.
        if epoch >= n_epochs:
            break
    pb.close()


2999it [12:56,  4.39it/s]

epoch - 4
lpnloss:0.1293, aux_edge_loss:0.0115, aux_seg_loss:0.2927, self_supervised_instance_loss:0.8574, aux_size_loss:0.4748


3001it [15:22, 30.75s/it]

mask_ap: [7.28442077e-01 6.52729593e-01 5.53512480e-01 4.47436631e-01
 3.18688217e-01 1.70477299e-01 5.07402970e-02 4.85481824e-03
 1.98393802e-05 3.08018523e-08]
box_ap: [0.70729711 0.18329339]


5999it [23:50,  4.13it/s]

epoch - 5
lpnloss:0.1274, aux_edge_loss:0.0113, aux_seg_loss:0.2935, self_supervised_instance_loss:0.8569, aux_size_loss:0.4735


6001it [26:16, 30.85s/it]

mask_ap: [7.27726944e-01 6.50544451e-01 5.60518476e-01 4.53828486e-01
 3.29050231e-01 1.83473393e-01 5.73567111e-02 4.58546716e-03
 1.57186339e-05 2.94978986e-08]
box_ap: [0.69781834 0.18571672]


8999it [34:45,  5.80it/s]

epoch - 6
lpnloss:0.1263, aux_edge_loss:0.0114, aux_seg_loss:0.2997, self_supervised_instance_loss:0.8572, aux_size_loss:0.4726


9001it [36:11, 18.09s/it]

mask_ap: [7.05326943e-01 6.32402377e-01 5.45732444e-01 4.32433293e-01
 3.08288495e-01 1.64401068e-01 4.73850784e-02 4.61077866e-03
 2.00809774e-05 3.09858457e-08]
box_ap: [0.68668721 0.16825018]


11999it [44:39,  4.36it/s]

epoch - 7
lpnloss:0.1251, aux_edge_loss:0.0116, aux_seg_loss:0.2988, self_supervised_instance_loss:0.8568, aux_size_loss:0.4714


12001it [47:05, 30.79s/it]

mask_ap: [6.65602905e-01 5.91378026e-01 5.15534028e-01 4.14065342e-01
 2.93778311e-01 1.60895714e-01 5.04911295e-02 4.51283167e-03
 1.38459266e-05 2.96980539e-08]
box_ap: [0.63343889 0.15873369]


13624it [52:37,  4.31it/s]


KeyboardInterrupt: 