In [1]:
import numpy as np
import pandas as pd
from tqdm import tqdm
from transformers import MaskFormerFeatureExtractor, MaskFormerForInstanceSegmentation
import PIL
from PIL import Image
import requests
from transformers import pipeline
import datasets
from datasets import load_dataset
import os
import evaluate
import torch
import cv2
import json
import codecs

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
image_folder = './WE3DS/images/'
annotation_folder = './WE3DS/annotations/segmentation/SegmentationLabel/'
# Define the paths to the images and annotations
all_image_names = np.array(os.listdir(image_folder))

In [3]:
plant_classification = {
    'void': 'void',
    'soil': 'soil',
    'broad bean': 'crop',
    'corn spurry': 'weed',
    'red-root amaranth': 'weed',
    'common buckwheat': 'crop',
    'pea': 'crop',
    'red fingergrass': 'weed',
    'common wild oat': 'weed',
    'cornflower': 'weed',
    'corn cockle': 'weed',
    'corn': 'crop',
    'milk thistle': 'weed',
    'rye brome': 'weed',
    'soybean': 'crop',
    'sunflower': 'crop',
    'narrow-leaved plantain': 'weed',
    'small-flower geranium': 'weed',
    'sugar beet': 'crop'
}

In [4]:
def get_image_meta_filepath(plant_name):
    return './meta/' + plant_name + '_images.json'

In [5]:
def get_image_list_for_crop(crop_name):
    # Create an empty list to store the dataset
    image_list = []
    crop_image_names = json.load(codecs.open(get_image_meta_filepath(crop_name), 'r', 'utf-8-sig'))
    weed_image_names = json.load(codecs.open(get_image_meta_filepath('no_crop'), 'r', 'utf-8-sig'))
    image_names = crop_image_names + weed_image_names

    # Iterate over the image and annotation paths
    for image_name in crop_image_names:
        # Load the image and annotation using PIL
        image = Image.open(image_folder + image_name)
        annotation = Image.open(annotation_folder + image_name)
        
        # Create a dictionary entry for the dataseta
        entry = {'image': image, 'annotation': annotation}
        
        # Add the entry to the dataset
        image_list.append(entry)

    return image_list

In [6]:
def create_and_split_dataset_for_crop(crop_image_list):
    dataset = datasets.Dataset.from_list(crop_image_list)
    dataset = dataset.train_test_split(test_size=0.5)
    train_ds = dataset["train"]
    val_ds, test_ds = dataset["test"].train_test_split(test_size=0.5).values()
    return train_ds, val_ds, test_ds

In [7]:
broad_bean_image_list = get_image_list_for_crop('broad_bean')
broad_bean_train_ds, broad_bean_val_ds, broad_bean_test_ds = create_and_split_dataset_for_crop(broad_bean_image_list)

In [8]:
print("Training subset number of images: " + str(broad_bean_train_ds.num_rows))
print("Validation subset number of images: " + str(broad_bean_val_ds.num_rows))
print("Test subset number of images: " + str(broad_bean_test_ds.num_rows))

Training subset number of images: 105
Validation subset number of images: 52
Test subset number of images: 53


In [9]:
from transformers import AutoImageProcessor

checkpoint = "nvidia/mit-b0"
image_processor = AutoImageProcessor.from_pretrained(checkpoint, reduce_labels=True)

Could not find image processor class in the image processor config or the model config. Loading based on pattern matching with the model's feature extractor configuration.


In [10]:
def train_transforms(example_batch):
    images = [x for x in example_batch["image"]]
    labels = [x for x in example_batch["annotation"]]
    inputs = image_processor(images, labels)
    return inputs

In [11]:
metric = evaluate.load("mean_iou")

In [12]:
with open('./WE3DS/class_names.txt', 'r') as file:
    labels = [line.strip() for line in file]

ids = list(range(1, 20))

id2label = dict(zip(ids, labels))
label2id = dict(zip(labels, ids))

num_labels = len(labels)

print("id2label:", id2label)
print("label2id:", label2id)

id2label: {1: 'void', 2: 'soil', 3: 'broad bean', 4: 'corn spurry', 5: 'red-root amaranth', 6: 'common buckwheat', 7: 'pea', 8: 'red fingergrass', 9: 'common wild oat', 10: 'cornflower', 11: 'corn cockle', 12: 'corn', 13: 'milk thistle', 14: 'rye brome', 15: 'soybean', 16: 'sunflower', 17: 'narrow-leaved plantain', 18: 'small-flower geranium', 19: 'sugar beet'}
label2id: {'void': 1, 'soil': 2, 'broad bean': 3, 'corn spurry': 4, 'red-root amaranth': 5, 'common buckwheat': 6, 'pea': 7, 'red fingergrass': 8, 'common wild oat': 9, 'cornflower': 10, 'corn cockle': 11, 'corn': 12, 'milk thistle': 13, 'rye brome': 14, 'soybean': 15, 'sunflower': 16, 'narrow-leaved plantain': 17, 'small-flower geranium': 18, 'sugar beet': 19}


In [13]:
def compute_metrics(eval_pred):
    with torch.no_grad():
        logits, labels = eval_pred
        logits_tensor = torch.from_numpy(logits)
        logits_tensor = torch.nn.functional.interpolate(
            logits_tensor,
            size=labels.shape[-2:],
            mode="bilinear",
            align_corners=False,
        ).argmax(dim=1)

        pred_labels = logits_tensor.detach().cpu().numpy()
        metrics = metric.compute(
            predictions=pred_labels,
            references=labels,
            num_labels=num_labels,
            ignore_index=255,
            reduce_labels=False,
        )
        for key, value in metrics.items():
            if type(value) is np.ndarray:
                metrics[key] = value.tolist()
        return metrics

In [14]:
from transformers import AutoModelForSemanticSegmentation, TrainingArguments, Trainer, EarlyStoppingCallback

model = AutoModelForSemanticSegmentation.from_pretrained(checkpoint, id2label=id2label, label2id=label2id)

Some weights of the model checkpoint at nvidia/mit-b0 were not used when initializing SegformerForSemanticSegmentation: ['classifier.weight', 'classifier.bias']
- This IS expected if you are initializing SegformerForSemanticSegmentation from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing SegformerForSemanticSegmentation from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of SegformerForSemanticSegmentation were not initialized from the model checkpoint at nvidia/mit-b0 and are newly initialized: ['decode_head.batch_norm.bias', 'decode_head.batch_norm.weight', 'decode_head.linear_c.2.proj.weight', 'decode_head.linear_c.1.proj.weight', 'decode_head.classifier.weight', 'decode_head.linear_c.2.pr

In [15]:
training_args = TrainingArguments(
    output_dir="segformer-b0-scene-parse-150",
    learning_rate=6e-5,
    num_train_epochs=10,
    per_device_train_batch_size=10,
    per_device_eval_batch_size=10,
    save_total_limit=3,
    evaluation_strategy="steps",
    save_strategy="steps",
    save_steps=20,
    eval_steps=20,
    logging_steps=1,
    eval_accumulation_steps=5,
    remove_unused_columns=False,
    load_best_model_at_end=True,
)

In [16]:
def initialize_trainer(train_ds, test_ds) :
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_ds,
        eval_dataset=test_ds,
        compute_metrics=compute_metrics,
        callbacks=[EarlyStoppingCallback(early_stopping_patience=3)]
    )
    return trainer

In [17]:
broad_bean_train_ds.set_transform(train_transforms)
broad_bean_val_ds.set_transform(train_transforms)
broad_bean_test_ds.set_transform(train_transforms)

In [18]:
braod_bean_trainer = initialize_trainer(broad_bean_train_ds, broad_bean_val_ds)
braod_bean_trainer.train()

  1%|          | 1/110 [00:17<30:55, 17.03s/it]

{'loss': 3.0092, 'learning_rate': 5.945454545454546e-05, 'epoch': 0.09}


  2%|▏         | 2/110 [00:33<29:49, 16.57s/it]

{'loss': 2.9758, 'learning_rate': 5.890909090909091e-05, 'epoch': 0.18}


  3%|▎         | 3/110 [00:49<29:22, 16.47s/it]

{'loss': 2.9369, 'learning_rate': 5.836363636363637e-05, 'epoch': 0.27}


  4%|▎         | 4/110 [01:05<28:53, 16.35s/it]

{'loss': 2.9338, 'learning_rate': 5.781818181818182e-05, 'epoch': 0.36}


  5%|▍         | 5/110 [01:22<28:43, 16.41s/it]

{'loss': 2.8678, 'learning_rate': 5.7272727272727274e-05, 'epoch': 0.45}


  5%|▌         | 6/110 [01:38<28:28, 16.43s/it]

{'loss': 2.8533, 'learning_rate': 5.6727272727272726e-05, 'epoch': 0.55}


  6%|▋         | 7/110 [01:55<28:06, 16.37s/it]

{'loss': 2.8616, 'learning_rate': 5.6181818181818184e-05, 'epoch': 0.64}


  7%|▋         | 8/110 [02:11<27:51, 16.39s/it]

{'loss': 2.8397, 'learning_rate': 5.5636363636363636e-05, 'epoch': 0.73}


  8%|▊         | 9/110 [02:28<27:51, 16.55s/it]

{'loss': 2.7624, 'learning_rate': 5.5090909090909094e-05, 'epoch': 0.82}


  9%|▉         | 10/110 [02:45<27:51, 16.71s/it]

{'loss': 2.7696, 'learning_rate': 5.4545454545454546e-05, 'epoch': 0.91}


 10%|█         | 11/110 [02:54<23:44, 14.39s/it]

{'loss': 2.7335, 'learning_rate': 5.4000000000000005e-05, 'epoch': 1.0}


 11%|█         | 12/110 [03:11<24:32, 15.03s/it]

{'loss': 2.6787, 'learning_rate': 5.3454545454545457e-05, 'epoch': 1.09}


 12%|█▏        | 13/110 [03:28<25:20, 15.67s/it]

{'loss': 2.6233, 'learning_rate': 5.290909090909091e-05, 'epoch': 1.18}


 13%|█▎        | 14/110 [03:44<25:11, 15.75s/it]

{'loss': 2.6291, 'learning_rate': 5.236363636363636e-05, 'epoch': 1.27}


 14%|█▎        | 15/110 [04:00<25:04, 15.83s/it]

{'loss': 2.662, 'learning_rate': 5.181818181818182e-05, 'epoch': 1.36}


 15%|█▍        | 16/110 [04:16<24:52, 15.87s/it]

{'loss': 2.6556, 'learning_rate': 5.127272727272727e-05, 'epoch': 1.45}


 15%|█▌        | 17/110 [04:32<24:43, 15.95s/it]

{'loss': 2.553, 'learning_rate': 5.072727272727273e-05, 'epoch': 1.55}


 16%|█▋        | 18/110 [04:48<24:26, 15.94s/it]

{'loss': 2.503, 'learning_rate': 5.018181818181818e-05, 'epoch': 1.64}


 17%|█▋        | 19/110 [05:04<24:22, 16.07s/it]

{'loss': 2.4517, 'learning_rate': 4.963636363636364e-05, 'epoch': 1.73}


 18%|█▊        | 20/110 [05:20<24:14, 16.17s/it]

{'loss': 2.4555, 'learning_rate': 4.90909090909091e-05, 'epoch': 1.82}


  acc = total_area_intersect / total_area_label
                                                
 18%|█▊        | 20/110 [07:58<24:14, 16.17s/it]

{'eval_loss': 2.826364040374756, 'eval_mean_iou': 0.06909370840388247, 'eval_mean_accuracy': 0.9717052117918751, 'eval_overall_accuracy': 0.9458694730836149, 'eval_per_category_iou': [0.9445901064015024, 0.36819035327226474, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], 'eval_per_category_accuracy': [0.9446111806927062, 0.9987992428910439, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan], 'eval_runtime': 157.7162, 'eval_samples_per_second': 0.33, 'eval_steps_per_second': 0.038, 'epoch': 1.82}


 19%|█▉        | 21/110 [08:16<1:34:54, 63.99s/it]

{'loss': 2.4991, 'learning_rate': 4.854545454545455e-05, 'epoch': 1.91}


 20%|██        | 22/110 [08:25<1:09:52, 47.65s/it]

{'loss': 2.4588, 'learning_rate': 4.8e-05, 'epoch': 2.0}


 21%|██        | 23/110 [08:43<55:54, 38.55s/it]  

{'loss': 2.4628, 'learning_rate': 4.745454545454545e-05, 'epoch': 2.09}


 22%|██▏       | 24/110 [09:00<46:10, 32.21s/it]

{'loss': 2.4627, 'learning_rate': 4.690909090909091e-05, 'epoch': 2.18}


 23%|██▎       | 25/110 [09:18<39:19, 27.76s/it]

{'loss': 2.4227, 'learning_rate': 4.636363636363636e-05, 'epoch': 2.27}


 24%|██▎       | 26/110 [09:35<34:29, 24.64s/it]

{'loss': 2.3346, 'learning_rate': 4.581818181818182e-05, 'epoch': 2.36}


 25%|██▍       | 27/110 [09:52<30:59, 22.40s/it]

{'loss': 2.3522, 'learning_rate': 4.5272727272727274e-05, 'epoch': 2.45}


 25%|██▌       | 28/110 [10:10<28:36, 20.94s/it]

{'loss': 2.3409, 'learning_rate': 4.472727272727273e-05, 'epoch': 2.55}


 26%|██▋       | 29/110 [10:27<26:53, 19.92s/it]

{'loss': 2.3048, 'learning_rate': 4.4181818181818184e-05, 'epoch': 2.64}


 27%|██▋       | 30/110 [10:45<25:39, 19.24s/it]

{'loss': 2.2901, 'learning_rate': 4.3636363636363636e-05, 'epoch': 2.73}


 28%|██▊       | 31/110 [11:02<24:36, 18.69s/it]

{'loss': 2.2801, 'learning_rate': 4.309090909090909e-05, 'epoch': 2.82}


 29%|██▉       | 32/110 [11:20<23:47, 18.30s/it]

{'loss': 2.2288, 'learning_rate': 4.2545454545454546e-05, 'epoch': 2.91}


 30%|███       | 33/110 [11:29<20:08, 15.70s/it]

{'loss': 2.403, 'learning_rate': 4.2e-05, 'epoch': 3.0}


 31%|███       | 34/110 [11:46<20:26, 16.14s/it]

{'loss': 2.3535, 'learning_rate': 4.1454545454545456e-05, 'epoch': 3.09}


 32%|███▏      | 35/110 [12:04<20:44, 16.59s/it]

{'loss': 2.2335, 'learning_rate': 4.090909090909091e-05, 'epoch': 3.18}


 33%|███▎      | 36/110 [12:22<20:48, 16.87s/it]

{'loss': 2.1911, 'learning_rate': 4.0363636363636367e-05, 'epoch': 3.27}


 34%|███▎      | 37/110 [12:39<20:46, 17.08s/it]

{'loss': 2.1927, 'learning_rate': 3.9818181818181825e-05, 'epoch': 3.36}


 35%|███▍      | 38/110 [12:57<20:54, 17.42s/it]

{'loss': 2.1894, 'learning_rate': 3.927272727272728e-05, 'epoch': 3.45}


 35%|███▌      | 39/110 [13:16<20:55, 17.68s/it]

{'loss': 2.1927, 'learning_rate': 3.872727272727273e-05, 'epoch': 3.55}


 36%|███▋      | 40/110 [13:34<20:55, 17.94s/it]

{'loss': 2.1924, 'learning_rate': 3.818181818181818e-05, 'epoch': 3.64}


  iou = total_area_intersect / total_area_union
  acc = total_area_intersect / total_area_label
                                                
 36%|███▋      | 40/110 [16:11<20:55, 17.94s/it]

{'eval_loss': 2.4098501205444336, 'eval_mean_iou': 0.11069646922135404, 'eval_mean_accuracy': 0.9796946095174454, 'eval_overall_accuracy': 0.9727614242309688, 'eval_per_category_iou': [0.9721263196913349, 0.4669277801862676, nan, 0.0, 0.0, 0.0, 0.0, nan, 0.0, nan, 0.0, 0.0, nan, 0.0, nan, 0.0, 0.0, 0.0, nan], 'eval_per_category_accuracy': [0.9724237534413217, 0.986965465593569, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan], 'eval_runtime': 156.4486, 'eval_samples_per_second': 0.332, 'eval_steps_per_second': 0.038, 'epoch': 3.64}


 37%|███▋      | 41/110 [16:27<1:13:58, 64.33s/it]

{'loss': 2.1771, 'learning_rate': 3.763636363636364e-05, 'epoch': 3.73}


 38%|███▊      | 42/110 [16:43<56:33, 49.90s/it]  

{'loss': 2.169, 'learning_rate': 3.709090909090909e-05, 'epoch': 3.82}


 39%|███▉      | 43/110 [16:59<44:27, 39.82s/it]

{'loss': 2.2428, 'learning_rate': 3.654545454545455e-05, 'epoch': 3.91}


 40%|████      | 44/110 [17:08<33:39, 30.60s/it]

{'loss': 2.0155, 'learning_rate': 3.6e-05, 'epoch': 4.0}


 41%|████      | 45/110 [17:25<28:42, 26.50s/it]

{'loss': 2.1332, 'learning_rate': 3.545454545454546e-05, 'epoch': 4.09}


 42%|████▏     | 46/110 [17:41<24:55, 23.37s/it]

{'loss': 2.0594, 'learning_rate': 3.490909090909091e-05, 'epoch': 4.18}


 43%|████▎     | 47/110 [17:57<22:13, 21.16s/it]

{'loss': 2.2001, 'learning_rate': 3.436363636363636e-05, 'epoch': 4.27}


 44%|████▎     | 48/110 [18:14<20:22, 19.72s/it]

{'loss': 2.0468, 'learning_rate': 3.3818181818181815e-05, 'epoch': 4.36}


 45%|████▍     | 49/110 [18:30<19:02, 18.73s/it]

{'loss': 2.0688, 'learning_rate': 3.327272727272727e-05, 'epoch': 4.45}


 45%|████▌     | 50/110 [18:47<18:05, 18.09s/it]

{'loss': 2.0312, 'learning_rate': 3.2727272727272725e-05, 'epoch': 4.55}


 46%|████▋     | 51/110 [19:04<17:33, 17.85s/it]

{'loss': 2.2186, 'learning_rate': 3.2181818181818184e-05, 'epoch': 4.64}


 47%|████▋     | 52/110 [19:21<17:03, 17.65s/it]

{'loss': 2.0978, 'learning_rate': 3.1636363636363635e-05, 'epoch': 4.73}


 48%|████▊     | 53/110 [19:38<16:39, 17.53s/it]

{'loss': 2.1335, 'learning_rate': 3.1090909090909094e-05, 'epoch': 4.82}


 49%|████▉     | 54/110 [19:56<16:21, 17.53s/it]

{'loss': 2.0211, 'learning_rate': 3.0545454545454546e-05, 'epoch': 4.91}


 50%|█████     | 55/110 [20:05<13:43, 14.96s/it]

{'loss': 1.9533, 'learning_rate': 3e-05, 'epoch': 5.0}


 51%|█████     | 56/110 [20:23<14:11, 15.78s/it]

{'loss': 2.1849, 'learning_rate': 2.9454545454545456e-05, 'epoch': 5.09}


 52%|█████▏    | 57/110 [20:40<14:27, 16.37s/it]

{'loss': 2.11, 'learning_rate': 2.890909090909091e-05, 'epoch': 5.18}


 53%|█████▎    | 58/110 [20:58<14:25, 16.65s/it]

{'loss': 2.0855, 'learning_rate': 2.8363636363636363e-05, 'epoch': 5.27}


 54%|█████▎    | 59/110 [21:15<14:19, 16.86s/it]

{'loss': 2.0442, 'learning_rate': 2.7818181818181818e-05, 'epoch': 5.36}


 55%|█████▍    | 60/110 [21:33<14:11, 17.04s/it]

{'loss': 1.9896, 'learning_rate': 2.7272727272727273e-05, 'epoch': 5.45}


  iou = total_area_intersect / total_area_union
  acc = total_area_intersect / total_area_label
                                                
 55%|█████▍    | 60/110 [23:55<14:11, 17.04s/it]

{'eval_loss': 1.7253865003585815, 'eval_mean_iou': 0.24895879640159216, 'eval_mean_accuracy': 0.9764682872797605, 'eval_overall_accuracy': 0.9785232147511229, 'eval_per_category_iou': [0.9780261417270578, 0.5157266366824951, nan, 0.0, nan, 0.0, nan, nan, nan, nan, nan, 0.0, nan, nan, nan, nan, 0.0, nan, nan], 'eval_per_category_accuracy': [0.9786232970297944, 0.9743132775297266, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan], 'eval_runtime': 142.1735, 'eval_samples_per_second': 0.366, 'eval_steps_per_second': 0.042, 'epoch': 5.45}


 55%|█████▌    | 61/110 [24:13<48:58, 59.96s/it]

{'loss': 2.0211, 'learning_rate': 2.6727272727272728e-05, 'epoch': 5.55}


 56%|█████▋    | 62/110 [24:30<37:44, 47.17s/it]

{'loss': 2.0315, 'learning_rate': 2.618181818181818e-05, 'epoch': 5.64}


 57%|█████▋    | 63/110 [24:48<30:01, 38.33s/it]

{'loss': 1.9932, 'learning_rate': 2.5636363636363635e-05, 'epoch': 5.73}


 58%|█████▊    | 64/110 [25:05<24:39, 32.15s/it]

{'loss': 1.9858, 'learning_rate': 2.509090909090909e-05, 'epoch': 5.82}


 59%|█████▉    | 65/110 [25:23<20:48, 27.75s/it]

{'loss': 2.0951, 'learning_rate': 2.454545454545455e-05, 'epoch': 5.91}


 60%|██████    | 66/110 [25:32<16:14, 22.14s/it]

{'loss': 2.0924, 'learning_rate': 2.4e-05, 'epoch': 6.0}


 61%|██████    | 67/110 [25:50<14:54, 20.81s/it]

{'loss': 1.9306, 'learning_rate': 2.3454545454545456e-05, 'epoch': 6.09}


 62%|██████▏   | 68/110 [26:07<13:50, 19.77s/it]

{'loss': 1.9345, 'learning_rate': 2.290909090909091e-05, 'epoch': 6.18}


 63%|██████▎   | 69/110 [26:25<13:04, 19.13s/it]

{'loss': 1.9788, 'learning_rate': 2.2363636363636366e-05, 'epoch': 6.27}


 64%|██████▎   | 70/110 [26:42<12:28, 18.72s/it]

{'loss': 1.9495, 'learning_rate': 2.1818181818181818e-05, 'epoch': 6.36}


 65%|██████▍   | 71/110 [27:00<11:52, 18.27s/it]

{'loss': 2.3299, 'learning_rate': 2.1272727272727273e-05, 'epoch': 6.45}


 65%|██████▌   | 72/110 [27:17<11:25, 18.05s/it]

{'loss': 1.9251, 'learning_rate': 2.0727272727272728e-05, 'epoch': 6.55}


 66%|██████▋   | 73/110 [27:35<11:01, 17.89s/it]

{'loss': 1.969, 'learning_rate': 2.0181818181818183e-05, 'epoch': 6.64}


 67%|██████▋   | 74/110 [27:52<10:40, 17.79s/it]

{'loss': 1.8727, 'learning_rate': 1.963636363636364e-05, 'epoch': 6.73}


 68%|██████▊   | 75/110 [28:10<10:19, 17.71s/it]

{'loss': 2.0391, 'learning_rate': 1.909090909090909e-05, 'epoch': 6.82}


 69%|██████▉   | 76/110 [28:27<10:00, 17.65s/it]

{'loss': 2.2392, 'learning_rate': 1.8545454545454545e-05, 'epoch': 6.91}


 70%|███████   | 77/110 [28:36<08:18, 15.10s/it]

{'loss': 1.9646, 'learning_rate': 1.8e-05, 'epoch': 7.0}


 71%|███████   | 78/110 [28:54<08:25, 15.79s/it]

{'loss': 2.0674, 'learning_rate': 1.7454545454545456e-05, 'epoch': 7.09}


 72%|███████▏  | 79/110 [29:11<08:27, 16.36s/it]

{'loss': 1.9589, 'learning_rate': 1.6909090909090907e-05, 'epoch': 7.18}


 73%|███████▎  | 80/110 [29:29<08:20, 16.68s/it]

{'loss': 1.9483, 'learning_rate': 1.6363636363636363e-05, 'epoch': 7.27}


  iou = total_area_intersect / total_area_union
  acc = total_area_intersect / total_area_label
                                                
 73%|███████▎  | 80/110 [31:51<08:20, 16.68s/it]

{'eval_loss': 1.8224740028381348, 'eval_mean_iou': 0.21741457268461725, 'eval_mean_accuracy': 0.9763504834000365, 'eval_overall_accuracy': 0.9808723229778371, 'eval_per_category_iou': [0.9804308114633828, 0.5414711973289379, nan, 0.0, nan, 0.0, nan, nan, nan, nan, 0.0, nan, nan, 0.0, nan, nan, 0.0, nan, nan], 'eval_per_category_accuracy': [0.9810925526526263, 0.9716084141474467, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan], 'eval_runtime': 141.8831, 'eval_samples_per_second': 0.366, 'eval_steps_per_second': 0.042, 'epoch': 7.27}


 74%|███████▎  | 81/110 [32:08<28:40, 59.33s/it]

{'loss': 1.8836, 'learning_rate': 1.5818181818181818e-05, 'epoch': 7.36}


 75%|███████▍  | 82/110 [32:25<21:46, 46.65s/it]

{'loss': 1.9192, 'learning_rate': 1.5272727272727273e-05, 'epoch': 7.45}


 75%|███████▌  | 83/110 [32:42<17:02, 37.88s/it]

{'loss': 1.8878, 'learning_rate': 1.4727272727272728e-05, 'epoch': 7.55}


 76%|███████▋  | 84/110 [33:00<13:45, 31.74s/it]

{'loss': 1.8765, 'learning_rate': 1.4181818181818181e-05, 'epoch': 7.64}


 77%|███████▋  | 85/110 [33:17<11:27, 27.48s/it]

{'loss': 1.8812, 'learning_rate': 1.3636363636363637e-05, 'epoch': 7.73}


 78%|███████▊  | 86/110 [33:35<09:47, 24.48s/it]

{'loss': 1.8688, 'learning_rate': 1.309090909090909e-05, 'epoch': 7.82}


 79%|███████▉  | 87/110 [33:52<08:34, 22.38s/it]

{'loss': 1.8859, 'learning_rate': 1.2545454545454545e-05, 'epoch': 7.91}


 80%|████████  | 88/110 [34:01<06:44, 18.38s/it]

{'loss': 1.9683, 'learning_rate': 1.2e-05, 'epoch': 8.0}


 81%|████████  | 89/110 [34:19<06:20, 18.12s/it]

{'loss': 1.8459, 'learning_rate': 1.1454545454545455e-05, 'epoch': 8.09}


 82%|████████▏ | 90/110 [34:36<05:57, 17.89s/it]

{'loss': 2.0552, 'learning_rate': 1.0909090909090909e-05, 'epoch': 8.18}


 83%|████████▎ | 91/110 [34:53<05:35, 17.64s/it]

{'loss': 2.1157, 'learning_rate': 1.0363636363636364e-05, 'epoch': 8.27}


 84%|████████▎ | 92/110 [35:11<05:18, 17.69s/it]

{'loss': 2.0524, 'learning_rate': 9.81818181818182e-06, 'epoch': 8.36}


 85%|████████▍ | 93/110 [35:29<05:00, 17.71s/it]

{'loss': 1.9352, 'learning_rate': 9.272727272727273e-06, 'epoch': 8.45}


 85%|████████▌ | 94/110 [35:47<04:44, 17.77s/it]

{'loss': 1.86, 'learning_rate': 8.727272727272728e-06, 'epoch': 8.55}


 86%|████████▋ | 95/110 [36:04<04:25, 17.69s/it]

{'loss': 1.8077, 'learning_rate': 8.181818181818181e-06, 'epoch': 8.64}


 87%|████████▋ | 96/110 [36:22<04:07, 17.66s/it]

{'loss': 1.9158, 'learning_rate': 7.636363636363636e-06, 'epoch': 8.73}


 88%|████████▊ | 97/110 [36:39<03:48, 17.60s/it]

{'loss': 1.8067, 'learning_rate': 7.090909090909091e-06, 'epoch': 8.82}


 89%|████████▉ | 98/110 [36:57<03:31, 17.61s/it]

{'loss': 1.8853, 'learning_rate': 6.545454545454545e-06, 'epoch': 8.91}


 90%|█████████ | 99/110 [37:06<02:45, 15.05s/it]

{'loss': 1.8781, 'learning_rate': 6e-06, 'epoch': 9.0}


 91%|█████████ | 100/110 [37:23<02:37, 15.76s/it]

{'loss': 1.8573, 'learning_rate': 5.4545454545454545e-06, 'epoch': 9.09}


  iou = total_area_intersect / total_area_union
  acc = total_area_intersect / total_area_label
                                                 
 91%|█████████ | 100/110 [39:44<02:37, 15.76s/it]

{'eval_loss': 1.8133410215377808, 'eval_mean_iou': 0.31451255122000676, 'eval_mean_accuracy': 0.9682472518349763, 'eval_overall_accuracy': 0.9845462211165715, 'eval_per_category_iou': [0.9841971906800393, 0.5883655654199944, nan, 0.0, 0.0, nan, nan, nan, nan, nan, nan, nan, nan, 0.0, nan, nan, nan, nan, nan], 'eval_per_category_accuracy': [0.9853400389074338, 0.9511544647625186, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan], 'eval_runtime': 140.6749, 'eval_samples_per_second': 0.37, 'eval_steps_per_second': 0.043, 'epoch': 9.09}


 92%|█████████▏| 101/110 [40:02<08:47, 58.62s/it]

{'loss': 1.8152, 'learning_rate': 4.90909090909091e-06, 'epoch': 9.18}


 93%|█████████▎| 102/110 [40:19<06:10, 46.29s/it]

{'loss': 1.8628, 'learning_rate': 4.363636363636364e-06, 'epoch': 9.27}


 94%|█████████▎| 103/110 [40:37<04:23, 37.65s/it]

{'loss': 2.1175, 'learning_rate': 3.818181818181818e-06, 'epoch': 9.36}


 95%|█████████▍| 104/110 [40:54<03:09, 31.57s/it]

{'loss': 1.9229, 'learning_rate': 3.2727272727272725e-06, 'epoch': 9.45}


 95%|█████████▌| 105/110 [41:11<02:15, 27.18s/it]

{'loss': 1.9264, 'learning_rate': 2.7272727272727272e-06, 'epoch': 9.55}


 96%|█████████▋| 106/110 [41:28<01:36, 24.16s/it]

{'loss': 2.0426, 'learning_rate': 2.181818181818182e-06, 'epoch': 9.64}


 97%|█████████▋| 107/110 [41:46<01:06, 22.21s/it]

{'loss': 1.81, 'learning_rate': 1.6363636363636363e-06, 'epoch': 9.73}


 98%|█████████▊| 108/110 [42:03<00:41, 20.77s/it]

{'loss': 1.8099, 'learning_rate': 1.090909090909091e-06, 'epoch': 9.82}


 99%|█████████▉| 109/110 [42:21<00:19, 19.84s/it]

{'loss': 1.8384, 'learning_rate': 5.454545454545455e-07, 'epoch': 9.91}


100%|██████████| 110/110 [42:30<00:00, 23.19s/it]

{'loss': 2.063, 'learning_rate': 0.0, 'epoch': 10.0}
{'train_runtime': 2550.7498, 'train_samples_per_second': 0.412, 'train_steps_per_second': 0.043, 'train_loss': 2.189501702785492, 'epoch': 10.0}





TrainOutput(global_step=110, training_loss=2.189501702785492, metrics={'train_runtime': 2550.7498, 'train_samples_per_second': 0.412, 'train_steps_per_second': 0.043, 'train_loss': 2.189501702785492, 'epoch': 10.0})

In [27]:
braod_bean_trainer.state.log_history

[{'loss': 3.0092,
  'learning_rate': 5.945454545454546e-05,
  'epoch': 0.09,
  'step': 1},
 {'loss': 2.9758,
  'learning_rate': 5.890909090909091e-05,
  'epoch': 0.18,
  'step': 2},
 {'loss': 2.9369,
  'learning_rate': 5.836363636363637e-05,
  'epoch': 0.27,
  'step': 3},
 {'loss': 2.9338,
  'learning_rate': 5.781818181818182e-05,
  'epoch': 0.36,
  'step': 4},
 {'loss': 2.8678,
  'learning_rate': 5.7272727272727274e-05,
  'epoch': 0.45,
  'step': 5},
 {'loss': 2.8533,
  'learning_rate': 5.6727272727272726e-05,
  'epoch': 0.55,
  'step': 6},
 {'loss': 2.8616,
  'learning_rate': 5.6181818181818184e-05,
  'epoch': 0.64,
  'step': 7},
 {'loss': 2.8397,
  'learning_rate': 5.5636363636363636e-05,
  'epoch': 0.73,
  'step': 8},
 {'loss': 2.7624,
  'learning_rate': 5.5090909090909094e-05,
  'epoch': 0.82,
  'step': 9},
 {'loss': 2.7696,
  'learning_rate': 5.4545454545454546e-05,
  'epoch': 0.91,
  'step': 10},
 {'loss': 2.7335,
  'learning_rate': 5.4000000000000005e-05,
  'epoch': 1.0,
  'ste

In [1]:
braod_bean_trainer.save_model('models/broad_bean')
# broad_bean_test_ds

NameError: name 'braod_bean_trainer' is not defined

In [20]:
# Interpret results
# Add method to rewrite labels in the annotions