In [1]:
import os

# Write both CUDA debugging flags into the process environment before torch
# is imported in a later cell.
# NOTE(review): CUDA_LAUNCH_BLOCKING=1 presumably makes kernel launches
# synchronous so errors surface at the failing call; TORCH_USE_CUDA_DSA is
# normally a build-time option -- confirm it has any effect when set here.
os.environ.update({
    'CUDA_LAUNCH_BLOCKING': "1",
    'TORCH_USE_CUDA_DSA': "1",
})

In [2]:
%env CUDA_LAUNCH_BLOCKING=1

env: CUDA_LAUNCH_BLOCKING=1


In [3]:
import torch

# Prefer the GPU when CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f'Using device: {device}\n')

# Extra diagnostics that are only printed when running on a CUDA device.
if device.type == 'cuda':
    print(torch.cuda.get_device_name(0))
    print(f'CUDA version: {torch.version.cuda}')
    print('Memory Usage:')
    # Byte counts for device 0 are converted to GiB and rounded to one decimal.
    print(f'Allocated: {round(torch.cuda.memory_allocated(0) / 1024**3, 1)} GB')
    print(f'Cached:    {round(torch.cuda.memory_reserved(0) / 1024**3, 1)} GB')

Using device: cuda

NVIDIA GeForce GTX 960
CUDA version: 11.7
Memory Usage:
Allocated: 0.0 GB
Cached:    0.0 GB


In [4]:
import numpy as np
import PIL
from PIL import Image
import datasets
import evaluate
import torch
import json
import codecs
import os
from os import sys

from transformers import SegformerImageProcessor, AutoModelForSemanticSegmentation, TrainingArguments, Trainer, EarlyStoppingCallback

  from .autonotebook import tqdm as notebook_tqdm


In [5]:
# Make the project's ./src directory importable from this notebook.
module_path = os.path.abspath('./src')
if sys.path.count(module_path) == 0:
    sys.path.append(module_path)

# Project-local helpers that live under ./src.
from data_prepossessing import create_datasets_for_plants, get_labels
# NOTE(review): the star import presumably supplies `seed` and `weed_plants`
# used further down -- confirm and consider importing them by name.
from constants import *

In [6]:
# Checkpoint identifier shared by the image processor here and by the model
# loaders in later cells.
checkpoint = "nvidia/mit-b0"
# Image processor loaded from that checkpoint; train_transforms() calls it.
image_processor = SegformerImageProcessor.from_pretrained(checkpoint)
# Bare expression so the notebook displays the processor configuration.
image_processor



SegformerImageProcessor {
  "do_normalize": true,
  "do_reduce_labels": false,
  "do_rescale": true,
  "do_resize": true,
  "feature_extractor_type": "SegformerFeatureExtractor",
  "image_mean": [
    0.485,
    0.456,
    0.406
  ],
  "image_processor_type": "SegformerImageProcessor",
  "image_std": [
    0.229,
    0.224,
    0.225
  ],
  "resample": 2,
  "rescale_factor": 0.00392156862745098,
  "size": {
    "height": 512,
    "width": 512
  }
}

In [7]:
def train_transforms(example_batch):
    """Run the shared image processor over one batch of examples.

    Args:
        example_batch: Mapping with an "image" entry and an "annotation"
            entry, each an iterable of per-example items.

    Returns:
        The result of calling ``image_processor`` with the images, the
        annotations and ``return_tensors="pt"``.
    """
    batch_images = list(example_batch["image"])
    batch_masks = list(example_batch["annotation"])
    return image_processor(batch_images, batch_masks, return_tensors="pt")

In [8]:
# Metric object loaded by name; compute_metrics() calls metric.compute(...) on it.
metric = evaluate.load("mean_iou")

In [9]:
def compute_metrics(num_labels, eval_pred):
    """Turn raw evaluation logits into metrics from the module-level ``metric``.

    Args:
        num_labels: Number of classes, forwarded to ``metric.compute``.
        eval_pred: Pair ``(logits, labels)`` of NumPy arrays. The logits are
            resized to ``labels.shape[-2:]`` and reduced with an argmax over
            axis 1 (presumably the class axis of a ``(batch, classes, h, w)``
            array -- confirm against the model output).

    Returns:
        dict: The metric results, with every NumPy array value converted to a
        plain Python list; other values are passed through unchanged.
    """
    logits, labels = eval_pred

    with torch.no_grad():
        # Bring the logits to the label resolution before picking a class
        # per pixel.
        upsampled_logits = torch.nn.functional.interpolate(
            torch.from_numpy(logits),
            size=labels.shape[-2:],
            mode="bilinear",
            align_corners=False,
        )
        # torch.from_numpy always yields a CPU tensor, and no graph is built
        # under no_grad, so no detach()/cpu() hop is needed before numpy().
        pred_labels = upsampled_logits.argmax(dim=1).numpy()

    metrics = metric.compute(
        predictions=pred_labels,
        references=labels,
        num_labels=num_labels,
        ignore_index=255,
        reduce_labels=False,
    )

    # isinstance (rather than an exact type check) also converts ndarray
    # subclasses to lists.
    return {
        key: value.tolist() if isinstance(value, np.ndarray) else value
        for key, value in metrics.items()
    }

In [10]:
# Shared training configuration used by every Trainer built in this notebook.
training_args = TrainingArguments(
    # Output location and reproducibility.
    output_dir="segformer-b0-scene-parse-150",
    seed=seed,
    # Optimisation settings.
    learning_rate=6e-5,
    num_train_epochs=50,
    per_device_train_batch_size=6,
    per_device_eval_batch_size=6,
    # Evaluation and checkpointing share the same 100-step cadence.
    evaluation_strategy="steps",
    eval_steps=100,
    save_strategy="steps",
    save_steps=100,
    save_total_limit=3,
    load_best_model_at_end=True,
    # Log the training loss at every step.
    logging_steps=1,
    # NOTE(review): presumably needed so the "image"/"annotation" columns
    # survive until the set_transform callback runs -- confirm.
    remove_unused_columns=False,
)

In [11]:
def initialize_trainer(model, num_labels, train_ds, test_ds):
    """Build a ``Trainer`` for an already-instantiated model.

    Args:
        model: Model instance to train.
        num_labels: Number of classes, bound into the metrics callback.
        train_ds: Dataset used for training.
        test_ds: Dataset passed as the trainer's ``eval_dataset``.

    Returns:
        A ``Trainer`` configured with the shared ``training_args``, the
        ``compute_metrics`` callback and early stopping with a patience of 3.
    """
    def _metrics_for(eval_pred):
        # Bind num_labels so the trainer can call the metric function with
        # a single argument.
        return compute_metrics(num_labels, eval_pred)

    early_stopping = EarlyStoppingCallback(early_stopping_patience=3)
    return Trainer(
        model=model,
        args=training_args,
        train_dataset=train_ds,
        eval_dataset=test_ds,
        compute_metrics=_metrics_for,
        callbacks=[early_stopping],
    )

In [12]:
def model_init():
    """Create a fresh model for the multiclass ``broad_bean`` label set.

    Takes no arguments; it is handed to ``Trainer(model_init=...)`` by the
    hyperparameter-search helper.

    Returns:
        The model loaded from ``checkpoint`` with the label mappings returned
        by ``get_labels``.
    """
    crop = "broad_bean"
    model_type = "multiclass"
    label_maps = get_labels(crop, model_type)
    id2label, label2id = label_maps
    return AutoModelForSemanticSegmentation.from_pretrained(
        checkpoint,
        id2label=id2label,
        label2id=label2id,
    )

In [13]:
import optuna
optuna.__version__
import sigopt
# import wandb
import ray
from ray import tune

def sigopt_hp_space(trial):
    """Return the SigOpt-style search space: one ``learning_rate`` parameter.

    Args:
        trial: Unused; accepted so the function fits the ``hp_space`` callback
            signature.

    Returns:
        list[dict]: A single parameter description of type "double" with
        bounds 1e-6 to 1e-4.
    """
    learning_rate_param = {
        "bounds": {"min": 1e-6, "max": 1e-4},
        "name": "learning_rate",
        "type": "double",
    }
    return [learning_rate_param]

def optuna_hp_space(trial):
    """Sample the Optuna search space: a single ``learning_rate`` value.

    Args:
        trial: Object exposing ``suggest_float``; it is asked for
            "learning_rate" between 1e-6 and 1e-4 with ``log=True``.

    Returns:
        dict: ``{"learning_rate": <value returned by trial.suggest_float>}``.
    """
    sampled_lr = trial.suggest_float("learning_rate", 1e-6, 1e-4, log=True)
    return {"learning_rate": sampled_lr}

# def wandb_hp_space(trial):
#     return {
#         "method": "random",
#         "metric": {"name": "objective", "goal": "minimize"},
#         "parameters": {
#             "learning_rate": {"distribution": "uniform", "min": 1e-6, "max": 1e-4}
#         },
#     }

def ray_hp_space(trial):
    """Return the Ray Tune search space: a single ``learning_rate`` entry.

    Args:
        trial: Unused; accepted so the function fits the ``hp_space`` callback
            signature.

    Returns:
        dict: ``{"learning_rate": tune.loguniform(1e-6, 1e-4)}``.
    """
    lr_distribution = tune.loguniform(1e-6, 1e-4)
    return {"learning_rate": lr_distribution}

2023-09-09 16:40:17,862	INFO util.py:159 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.
2023-09-09 16:40:17,930	INFO util.py:159 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.


In [14]:
def initialize_trainer_for_hp_search(num_labels, train_ds, test_ds, n_trials=20):
    """Run an Optuna learning-rate search and return the best run.

    A ``Trainer`` is built with ``model_init`` (so each trial starts from a
    fresh model), the shared ``training_args``, the ``compute_metrics``
    callback and early stopping with a patience of 3.

    Args:
        num_labels: Number of classes, bound into the metrics callback.
        train_ds: Dataset used for training in every trial.
        test_ds: Dataset passed as the trainer's ``eval_dataset``.
        n_trials: Number of trials to run; defaults to 20.

    Returns:
        Whatever ``Trainer.hyperparameter_search`` returns for the best trial.
    """
    trainer = Trainer(
        model_init=model_init,
        args=training_args,
        train_dataset=train_ds,
        eval_dataset=test_ds,
        compute_metrics=lambda eval_pred: compute_metrics(num_labels, eval_pred),
        callbacks=[EarlyStoppingCallback(early_stopping_patience=3)]
    )

    def _eval_loss_objective(metrics):
        # Without an explicit objective the trainer's default sums every
        # non-loss evaluation metric. compute_metrics() reports IoU/accuracy
        # scores (higher is better) and per-category lists, so that sum is
        # neither something to minimise nor well defined. Optimise the
        # evaluation loss, which matches direction="minimize".
        return metrics["eval_loss"]

    best_trial = trainer.hyperparameter_search(
        direction="minimize",
        backend="optuna",
        hp_space=optuna_hp_space,
        n_trials=n_trials,
        compute_objective=_eval_loss_objective,
    )

    return best_trial

In [15]:
# Configuration for the hyperparameter search run.
model_type = "multiclass"
crop = "broad_bean"

# The crop itself plus every weed species forms the plant list for the datasets.
model_plant_names = [crop] + weed_plants
train_ds, val_ds, test_ds = create_datasets_for_plants(model_plant_names, model_type, crop)
id2label, label2id = get_labels(crop, model_type)

# Every split goes through the same image-processor transform.
train_ds.set_transform(train_transforms)
val_ds.set_transform(train_transforms)
test_ds.set_transform(train_transforms)

# The search trains on the training split and evaluates on the validation split.
best_trial = initialize_trainer_for_hp_search(len(id2label), train_ds, val_ds)
best_trial

['img_00173.png', 'img_00174.png', 'img_00175.png', 'img_00176.png', 'img_00177.png', 'img_00178.png', 'img_00672.png', 'img_00673.png', 'img_00674.png', 'img_00675.png', 'img_00676.png', 'img_00677.png', 'img_00678.png', 'img_00679.png', 'img_00680.png', 'img_00681.png', 'img_00682.png', 'img_00683.png', 'img_00684.png', 'img_00882.png', 'img_00883.png', 'img_00884.png', 'img_00885.png', 'img_00886.png', 'img_00887.png', 'img_00938.png', 'img_00980.png', 'img_00981.png', 'img_00982.png', 'img_00983.png', 'img_00984.png', 'img_00985.png', 'img_00986.png', 'img_00987.png', 'img_00988.png', 'img_00989.png', 'img_01070.png', 'img_01071.png', 'img_01072.png', 'img_01073.png', 'img_01074.png', 'img_01075.png', 'img_01076.png', 'img_01077.png', 'img_01078.png', 'img_01079.png', 'img_01219.png', 'img_01220.png', 'img_01221.png', 'img_01222.png', 'img_01223.png', 'img_01224.png', 'img_01225.png', 'img_01226.png', 'img_01227.png', 'img_01228.png', 'img_01279.png', 'img_01280.png', 'img_01281.pn

Some weights of SegformerForSemanticSegmentation were not initialized from the model checkpoint at nvidia/mit-b0 and are newly initialized: ['decode_head.linear_c.3.proj.bias', 'decode_head.batch_norm.running_var', 'decode_head.linear_c.1.proj.bias', 'decode_head.batch_norm.running_mean', 'decode_head.linear_fuse.weight', 'decode_head.batch_norm.weight', 'decode_head.batch_norm.bias', 'decode_head.linear_c.1.proj.weight', 'decode_head.batch_norm.num_batches_tracked', 'decode_head.linear_c.2.proj.bias', 'decode_head.linear_c.3.proj.weight', 'decode_head.classifier.weight', 'decode_head.linear_c.2.proj.weight', 'decode_head.linear_c.0.proj.weight', 'decode_head.classifier.bias', 'decode_head.linear_c.0.proj.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
[I 2023-09-09 16:40:19,293] A new study created in memory with name: no-name-6163023c-3647-4489-a47c-99c00c867b56
Some weights of SegformerForSemanticSegmentation were

{'loss': 2.7285, 'learning_rate': 1.023820073388892e-05, 'epoch': 0.01}


  0%|          | 2/3900 [00:06<3:22:57,  3.12s/it]

{'loss': 2.7358, 'learning_rate': 1.0235574880917929e-05, 'epoch': 0.03}


  0%|          | 3/3900 [00:09<3:06:22,  2.87s/it]

{'loss': 2.736, 'learning_rate': 1.0232949027946941e-05, 'epoch': 0.04}


  0%|          | 4/3900 [00:11<2:58:27,  2.75s/it]

{'loss': 2.731, 'learning_rate': 1.023032317497595e-05, 'epoch': 0.05}


  0%|          | 5/3900 [00:14<2:54:14,  2.68s/it]

{'loss': 2.7336, 'learning_rate': 1.022769732200496e-05, 'epoch': 0.06}


  0%|          | 6/3900 [00:16<2:51:57,  2.65s/it]

{'loss': 2.7156, 'learning_rate': 1.0225071469033971e-05, 'epoch': 0.08}


  0%|          | 7/3900 [00:19<2:50:54,  2.63s/it]

{'loss': 2.7023, 'learning_rate': 1.0222445616062982e-05, 'epoch': 0.09}


  0%|          | 8/3900 [00:21<2:49:31,  2.61s/it]

{'loss': 2.7067, 'learning_rate': 1.0219819763091993e-05, 'epoch': 0.1}


  0%|          | 9/3900 [00:24<2:48:25,  2.60s/it]

{'loss': 2.7051, 'learning_rate': 1.0217193910121002e-05, 'epoch': 0.12}


  0%|          | 10/3900 [00:27<2:47:48,  2.59s/it]

{'loss': 2.6948, 'learning_rate': 1.0214568057150012e-05, 'epoch': 0.13}


  0%|          | 11/3900 [00:29<2:46:59,  2.58s/it]

{'loss': 2.7044, 'learning_rate': 1.0211942204179023e-05, 'epoch': 0.14}


  0%|          | 12/3900 [00:32<2:46:45,  2.57s/it]

{'loss': 2.6609, 'learning_rate': 1.0209316351208033e-05, 'epoch': 0.15}


  0%|          | 13/3900 [00:34<2:46:30,  2.57s/it]

{'loss': 2.6824, 'learning_rate': 1.0206690498237044e-05, 'epoch': 0.17}


  0%|          | 14/3900 [00:37<2:46:08,  2.57s/it]

{'loss': 2.6774, 'learning_rate': 1.0204064645266053e-05, 'epoch': 0.18}


  0%|          | 15/3900 [00:39<2:46:10,  2.57s/it]

{'loss': 2.6467, 'learning_rate': 1.0201438792295063e-05, 'epoch': 0.19}


  0%|          | 16/3900 [00:42<2:46:00,  2.56s/it]

{'loss': 2.6896, 'learning_rate': 1.0198812939324074e-05, 'epoch': 0.21}


  0%|          | 17/3900 [00:44<2:46:05,  2.57s/it]

{'loss': 2.6328, 'learning_rate': 1.0196187086353085e-05, 'epoch': 0.22}


  0%|          | 18/3900 [00:47<2:46:21,  2.57s/it]

{'loss': 2.6514, 'learning_rate': 1.0193561233382095e-05, 'epoch': 0.23}


  0%|          | 19/3900 [00:50<2:46:34,  2.58s/it]

{'loss': 2.6619, 'learning_rate': 1.0190935380411104e-05, 'epoch': 0.24}


  1%|          | 20/3900 [00:52<2:46:21,  2.57s/it]

{'loss': 2.6572, 'learning_rate': 1.0188309527440115e-05, 'epoch': 0.26}


  1%|          | 21/3900 [00:55<2:46:10,  2.57s/it]

{'loss': 2.6547, 'learning_rate': 1.0185683674469125e-05, 'epoch': 0.27}


  1%|          | 22/3900 [00:57<2:45:53,  2.57s/it]

{'loss': 2.6441, 'learning_rate': 1.0183057821498136e-05, 'epoch': 0.28}


  1%|          | 23/3900 [01:00<2:45:40,  2.56s/it]

{'loss': 2.6462, 'learning_rate': 1.0180431968527147e-05, 'epoch': 0.29}


  1%|          | 24/3900 [01:02<2:45:31,  2.56s/it]

{'loss': 2.6591, 'learning_rate': 1.0177806115556156e-05, 'epoch': 0.31}


  1%|          | 25/3900 [01:05<2:45:20,  2.56s/it]

{'loss': 2.6323, 'learning_rate': 1.0175180262585166e-05, 'epoch': 0.32}


  1%|          | 26/3900 [01:08<2:45:31,  2.56s/it]

{'loss': 2.6265, 'learning_rate': 1.0172554409614177e-05, 'epoch': 0.33}


  1%|          | 27/3900 [01:10<2:45:29,  2.56s/it]

{'loss': 2.5967, 'learning_rate': 1.0169928556643187e-05, 'epoch': 0.35}


  1%|          | 28/3900 [01:13<2:45:19,  2.56s/it]

{'loss': 2.6272, 'learning_rate': 1.0167302703672198e-05, 'epoch': 0.36}


  1%|          | 29/3900 [01:15<2:45:14,  2.56s/it]

{'loss': 2.6197, 'learning_rate': 1.0164676850701207e-05, 'epoch': 0.37}


  1%|          | 30/3900 [01:18<2:45:19,  2.56s/it]

{'loss': 2.5897, 'learning_rate': 1.0162050997730217e-05, 'epoch': 0.38}


  1%|          | 31/3900 [01:20<2:45:27,  2.57s/it]

{'loss': 2.6227, 'learning_rate': 1.0159425144759228e-05, 'epoch': 0.4}


  1%|          | 32/3900 [01:23<2:45:39,  2.57s/it]

{'loss': 2.6073, 'learning_rate': 1.0156799291788239e-05, 'epoch': 0.41}


  1%|          | 33/3900 [01:26<2:45:32,  2.57s/it]

{'loss': 2.5825, 'learning_rate': 1.015417343881725e-05, 'epoch': 0.42}


  1%|          | 34/3900 [01:28<2:45:19,  2.57s/it]

{'loss': 2.5822, 'learning_rate': 1.0151547585846258e-05, 'epoch': 0.44}


  1%|          | 35/3900 [01:31<2:45:07,  2.56s/it]

{'loss': 2.5767, 'learning_rate': 1.0148921732875269e-05, 'epoch': 0.45}


  1%|          | 36/3900 [01:33<2:45:07,  2.56s/it]

{'loss': 2.55, 'learning_rate': 1.014629587990428e-05, 'epoch': 0.46}


  1%|          | 37/3900 [01:36<2:44:50,  2.56s/it]

{'loss': 2.6104, 'learning_rate': 1.014367002693329e-05, 'epoch': 0.47}


  1%|          | 38/3900 [01:38<2:44:40,  2.56s/it]

{'loss': 2.546, 'learning_rate': 1.0141044173962299e-05, 'epoch': 0.49}


  1%|          | 39/3900 [01:41<2:44:34,  2.56s/it]

{'loss': 2.5984, 'learning_rate': 1.013841832099131e-05, 'epoch': 0.5}


  1%|          | 40/3900 [01:43<2:44:25,  2.56s/it]

{'loss': 2.5944, 'learning_rate': 1.0135792468020322e-05, 'epoch': 0.51}


  1%|          | 41/3900 [01:46<2:44:16,  2.55s/it]

{'loss': 2.5403, 'learning_rate': 1.013316661504933e-05, 'epoch': 0.53}


  1%|          | 42/3900 [01:49<2:44:38,  2.56s/it]

{'loss': 2.4914, 'learning_rate': 1.0130540762078341e-05, 'epoch': 0.54}


  1%|          | 43/3900 [01:51<2:44:53,  2.57s/it]

{'loss': 2.5506, 'learning_rate': 1.012791490910735e-05, 'epoch': 0.55}


  1%|          | 44/3900 [01:54<2:44:32,  2.56s/it]

{'loss': 2.5951, 'learning_rate': 1.0125289056136363e-05, 'epoch': 0.56}


  1%|          | 45/3900 [01:56<2:44:42,  2.56s/it]

{'loss': 2.5624, 'learning_rate': 1.0122663203165373e-05, 'epoch': 0.58}


  1%|          | 46/3900 [01:59<2:44:34,  2.56s/it]

{'loss': 2.5015, 'learning_rate': 1.0120037350194382e-05, 'epoch': 0.59}


  1%|          | 47/3900 [02:01<2:44:19,  2.56s/it]

{'loss': 2.5225, 'learning_rate': 1.0117411497223393e-05, 'epoch': 0.6}


  1%|          | 48/3900 [02:04<2:44:23,  2.56s/it]

{'loss': 2.518, 'learning_rate': 1.0114785644252403e-05, 'epoch': 0.62}


  1%|▏         | 49/3900 [02:07<2:44:15,  2.56s/it]

{'loss': 2.4704, 'learning_rate': 1.0112159791281414e-05, 'epoch': 0.63}


  1%|▏         | 50/3900 [02:09<2:44:15,  2.56s/it]

{'loss': 2.4846, 'learning_rate': 1.0109533938310424e-05, 'epoch': 0.64}


  1%|▏         | 51/3900 [02:12<2:44:08,  2.56s/it]

{'loss': 2.4492, 'learning_rate': 1.0106908085339433e-05, 'epoch': 0.65}


  1%|▏         | 52/3900 [02:14<2:44:14,  2.56s/it]

{'loss': 2.5109, 'learning_rate': 1.0104282232368444e-05, 'epoch': 0.67}


  1%|▏         | 53/3900 [02:17<2:44:01,  2.56s/it]

{'loss': 2.4915, 'learning_rate': 1.0101656379397455e-05, 'epoch': 0.68}


  1%|▏         | 54/3900 [02:19<2:44:04,  2.56s/it]

{'loss': 2.458, 'learning_rate': 1.0099030526426465e-05, 'epoch': 0.69}


  1%|▏         | 55/3900 [02:22<2:43:49,  2.56s/it]

{'loss': 2.4828, 'learning_rate': 1.0096404673455476e-05, 'epoch': 0.71}


  1%|▏         | 56/3900 [02:24<2:44:02,  2.56s/it]

{'loss': 2.4695, 'learning_rate': 1.0093778820484485e-05, 'epoch': 0.72}


  1%|▏         | 57/3900 [02:27<2:44:04,  2.56s/it]

{'loss': 2.4987, 'learning_rate': 1.0091152967513495e-05, 'epoch': 0.73}


  1%|▏         | 58/3900 [02:30<2:44:07,  2.56s/it]

{'loss': 2.4364, 'learning_rate': 1.0088527114542506e-05, 'epoch': 0.74}


  2%|▏         | 59/3900 [02:32<2:43:48,  2.56s/it]

{'loss': 2.4429, 'learning_rate': 1.0085901261571517e-05, 'epoch': 0.76}


  2%|▏         | 60/3900 [02:35<2:44:02,  2.56s/it]

{'loss': 2.4619, 'learning_rate': 1.0083275408600527e-05, 'epoch': 0.77}


  2%|▏         | 61/3900 [02:37<2:44:04,  2.56s/it]

{'loss': 2.4738, 'learning_rate': 1.0080649555629536e-05, 'epoch': 0.78}


  2%|▏         | 62/3900 [02:40<2:44:21,  2.57s/it]

{'loss': 2.4939, 'learning_rate': 1.0078023702658547e-05, 'epoch': 0.79}


  2%|▏         | 63/3900 [02:42<2:44:08,  2.57s/it]

{'loss': 2.4424, 'learning_rate': 1.0075397849687557e-05, 'epoch': 0.81}


  2%|▏         | 64/3900 [02:45<2:44:29,  2.57s/it]

{'loss': 2.4398, 'learning_rate': 1.0072771996716568e-05, 'epoch': 0.82}


  2%|▏         | 65/3900 [02:48<2:44:30,  2.57s/it]

{'loss': 2.5395, 'learning_rate': 1.0070146143745577e-05, 'epoch': 0.83}


  2%|▏         | 66/3900 [02:50<2:44:07,  2.57s/it]

{'loss': 2.4045, 'learning_rate': 1.0067520290774587e-05, 'epoch': 0.85}


  2%|▏         | 67/3900 [02:53<2:43:54,  2.57s/it]

{'loss': 2.3623, 'learning_rate': 1.0064894437803598e-05, 'epoch': 0.86}


  2%|▏         | 68/3900 [02:55<2:44:08,  2.57s/it]

{'loss': 2.4309, 'learning_rate': 1.0062268584832609e-05, 'epoch': 0.87}


  2%|▏         | 69/3900 [02:58<2:44:15,  2.57s/it]

{'loss': 2.3988, 'learning_rate': 1.005964273186162e-05, 'epoch': 0.88}


  2%|▏         | 70/3900 [03:00<2:44:05,  2.57s/it]

{'loss': 2.4248, 'learning_rate': 1.0057016878890628e-05, 'epoch': 0.9}


  2%|▏         | 71/3900 [03:03<2:43:58,  2.57s/it]

{'loss': 2.4082, 'learning_rate': 1.0054391025919639e-05, 'epoch': 0.91}


  2%|▏         | 72/3900 [03:06<2:44:17,  2.58s/it]

{'loss': 2.3644, 'learning_rate': 1.005176517294865e-05, 'epoch': 0.92}


  2%|▏         | 73/3900 [03:08<2:44:17,  2.58s/it]

{'loss': 2.5203, 'learning_rate': 1.004913931997766e-05, 'epoch': 0.94}


  2%|▏         | 74/3900 [03:11<2:44:09,  2.57s/it]

{'loss': 2.34, 'learning_rate': 1.004651346700667e-05, 'epoch': 0.95}


  2%|▏         | 75/3900 [03:13<2:43:55,  2.57s/it]

{'loss': 2.3698, 'learning_rate': 1.004388761403568e-05, 'epoch': 0.96}


  2%|▏         | 76/3900 [03:16<2:43:50,  2.57s/it]

{'loss': 2.3657, 'learning_rate': 1.004126176106469e-05, 'epoch': 0.97}


  2%|▏         | 77/3900 [03:18<2:41:48,  2.54s/it]

{'loss': 2.3717, 'learning_rate': 1.0038635908093702e-05, 'epoch': 0.99}


  2%|▏         | 78/3900 [03:20<2:23:51,  2.26s/it]

{'loss': 2.4084, 'learning_rate': 1.0036010055122711e-05, 'epoch': 1.0}


  2%|▏         | 79/3900 [03:23<2:42:17,  2.55s/it]

{'loss': 2.3894, 'learning_rate': 1.0033384202151722e-05, 'epoch': 1.01}


  2%|▏         | 80/3900 [03:26<2:42:36,  2.55s/it]

{'loss': 2.4008, 'learning_rate': 1.003075834918073e-05, 'epoch': 1.03}


  2%|▏         | 81/3900 [03:28<2:42:44,  2.56s/it]

{'loss': 2.3322, 'learning_rate': 1.0028132496209743e-05, 'epoch': 1.04}


  2%|▏         | 82/3900 [03:31<2:42:44,  2.56s/it]

{'loss': 2.4394, 'learning_rate': 1.0025506643238754e-05, 'epoch': 1.05}


  2%|▏         | 83/3900 [03:33<2:42:50,  2.56s/it]

{'loss': 2.3533, 'learning_rate': 1.0022880790267763e-05, 'epoch': 1.06}


  2%|▏         | 84/3900 [03:36<2:42:54,  2.56s/it]

{'loss': 2.4142, 'learning_rate': 1.0020254937296773e-05, 'epoch': 1.08}


  2%|▏         | 85/3900 [03:38<2:42:41,  2.56s/it]

{'loss': 2.34, 'learning_rate': 1.0017629084325782e-05, 'epoch': 1.09}


  2%|▏         | 86/3900 [03:41<2:42:21,  2.55s/it]

{'loss': 2.2909, 'learning_rate': 1.0015003231354794e-05, 'epoch': 1.1}


  2%|▏         | 87/3900 [03:44<2:42:36,  2.56s/it]

{'loss': 2.3237, 'learning_rate': 1.0012377378383803e-05, 'epoch': 1.12}


  2%|▏         | 88/3900 [03:46<2:42:21,  2.56s/it]

{'loss': 2.2916, 'learning_rate': 1.0009751525412814e-05, 'epoch': 1.13}


  2%|▏         | 89/3900 [03:49<2:42:14,  2.55s/it]

{'loss': 2.3121, 'learning_rate': 1.0007125672441824e-05, 'epoch': 1.14}


  2%|▏         | 90/3900 [03:51<2:42:05,  2.55s/it]

{'loss': 2.2428, 'learning_rate': 1.0004499819470835e-05, 'epoch': 1.15}


  2%|▏         | 91/3900 [03:54<2:42:03,  2.55s/it]

{'loss': 2.2967, 'learning_rate': 1.0001873966499846e-05, 'epoch': 1.17}


  2%|▏         | 92/3900 [03:56<2:42:18,  2.56s/it]

{'loss': 2.2753, 'learning_rate': 9.999248113528855e-06, 'epoch': 1.18}


  2%|▏         | 93/3900 [03:59<2:42:20,  2.56s/it]

{'loss': 2.2785, 'learning_rate': 9.996622260557865e-06, 'epoch': 1.19}


  2%|▏         | 94/3900 [04:01<2:42:10,  2.56s/it]

{'loss': 2.3265, 'learning_rate': 9.993996407586876e-06, 'epoch': 1.21}


  2%|▏         | 95/3900 [04:04<2:42:32,  2.56s/it]

{'loss': 2.2555, 'learning_rate': 9.991370554615886e-06, 'epoch': 1.22}


  2%|▏         | 96/3900 [04:07<2:42:26,  2.56s/it]

{'loss': 2.3067, 'learning_rate': 9.988744701644897e-06, 'epoch': 1.23}


  2%|▏         | 97/3900 [04:09<2:42:28,  2.56s/it]

{'loss': 2.2787, 'learning_rate': 9.986118848673906e-06, 'epoch': 1.24}


  3%|▎         | 98/3900 [04:12<2:42:17,  2.56s/it]

{'loss': 2.268, 'learning_rate': 9.983492995702917e-06, 'epoch': 1.26}


  3%|▎         | 99/3900 [04:14<2:42:33,  2.57s/it]

{'loss': 2.2498, 'learning_rate': 9.980867142731927e-06, 'epoch': 1.27}


  3%|▎         | 100/3900 [04:17<2:42:07,  2.56s/it]

{'loss': 2.2782, 'learning_rate': 9.978241289760938e-06, 'epoch': 1.28}


[W 2023-09-09 16:45:52,829] Trial 0 failed with parameters: {'learning_rate': 1.024082658685991e-05} because of the following error: KeyboardInterrupt().
Traceback (most recent call last):
  File "/home/kate/miniconda3/envs/master/lib/python3.11/site-packages/optuna/study/_optimize.py", line 200, in _run_trial
    value_or_values = func(trial)
                      ^^^^^^^^^^^
  File "/home/kate/miniconda3/envs/master/lib/python3.11/site-packages/transformers/integrations/integration_utils.py", line 199, in _objective
    trainer.train(resume_from_checkpoint=checkpoint, trial=trial)
  File "/home/kate/miniconda3/envs/master/lib/python3.11/site-packages/transformers/trainer.py", line 1553, in train
    return inner_training_loop(
           ^^^^^^^^^^^^^^^^^^^^
  File "/home/kate/miniconda3/envs/master/lib/python3.11/site-packages/transformers/trainer.py", line 1927, in _inner_training_loop
    self._maybe_log_save_evaluate(tr_loss, model, trial, epoch, ignore_keys_for_eval)
  File "/ho

KeyboardInterrupt: 

In [13]:
def train_model_of_type_for_crop(model_type, crop):
    """Train, save and evaluate a segmentation model for one crop.

    The datasets are built for the crop plus all ``weed_plants``, the model is
    trained on the training split with the validation split as the evaluation
    set, and the result is evaluated once on the test split.

    Files written under ``models/<model_type>/<crop>``:
        * the trained model (via ``trainer.save_model``),
        * ``log_history.json`` -- ``trainer.state.log_history``,
        * ``test_metric.json`` -- the test-split evaluation metrics.

    Args:
        model_type: Model variant name, forwarded to the dataset and label
            helpers and used as a directory name.
        crop: Crop name, forwarded to the same helpers and used as a
            directory name.

    Returns:
        dict: The metrics returned by ``trainer.evaluate(test_ds)``.
    """
    model_plant_names = [crop] + weed_plants
    train_ds, val_ds, test_ds = create_datasets_for_plants(model_plant_names, model_type, crop)

    print("Training subset number of images: " + str(train_ds.num_rows))
    print("Validation subset number of images: " + str(val_ds.num_rows))
    print("Test subset number of images: " + str(test_ds.num_rows))

    train_ds.set_transform(train_transforms)
    val_ds.set_transform(train_transforms)
    test_ds.set_transform(train_transforms)

    id2label, label2id = get_labels(crop, model_type)

    print('Number of classes:', len(id2label))
    print('id2label:', id2label)
    print('label2id:', label2id)

    model = AutoModelForSemanticSegmentation.from_pretrained(checkpoint, id2label=id2label, label2id=label2id)
    trainer = initialize_trainer(model, len(id2label), train_ds, val_ds)
    trainer.train()

    # Build the output directory once and make sure it exists before any of
    # the files below are written.
    model_dir = os.path.join('models', model_type, crop)
    os.makedirs(model_dir, exist_ok=True)

    # Save the trained model, so that it can be used for inference later.
    trainer.save_model(model_dir)

    # Save the log history, so that it can be used for plotting later.
    with open(os.path.join(model_dir, 'log_history.json'), 'w') as file:
        json.dump(trainer.state.log_history, file)

    test_metric = trainer.evaluate(test_ds)
    with open(os.path.join(model_dir, 'test_metric.json'), 'w') as file:
        json.dump(test_metric, file)

    # Returned (instead of a bare expression, which has no effect inside a
    # function) so the calling cell can display or reuse the metrics.
    return test_metric

In [14]:
# Train, save and evaluate the "multiclass" model for the "broad_bean" crop.
train_model_of_type_for_crop("multiclass", "broad_bean")
# The same call for the "sugar_beet" crop is currently disabled:
# train_model_of_type_for_crop("multiclass", "sugar_beet")

['img_00071.png', 'img_00072.png', 'img_00074.png', 'img_00083.png', 'img_00084.png', 'img_00085.png', 'img_00151.png', 'img_00152.png', 'img_00153.png', 'img_00154.png', 'img_00155.png', 'img_00156.png', 'img_00210.png', 'img_00211.png', 'img_00212.png', 'img_00213.png', 'img_00214.png', 'img_00215.png', 'img_00216.png', 'img_00217.png', 'img_00218.png', 'img_00220.png', 'img_00221.png', 'img_00226.png', 'img_00227.png', 'img_00229.png', 'img_00248.png', 'img_00249.png', 'img_00250.png', 'img_00251.png', 'img_00254.png', 'img_00255.png', 'img_00256.png', 'img_00257.png', 'img_00258.png', 'img_00259.png', 'img_00260.png', 'img_00264.png', 'img_00265.png', 'img_00266.png', 'img_00267.png', 'img_00268.png', 'img_00269.png', 'img_00271.png', 'img_00293.png', 'img_00294.png', 'img_00295.png', 'img_00296.png', 'img_00297.png', 'img_00298.png', 'img_00299.png', 'img_00300.png', 'img_00301.png', 'img_00302.png', 'img_00303.png', 'img_00379.png', 'img_00380.png', 'img_00381.png', 'img_00382.pn

['img_01804.png', 'img_01807.png', 'img_01790.png', 'img_01803.png', 'img_01793.png', 'img_01797.png', 'img_01798.png', 'img_01796.png', 'img_01799.png', 'img_01795.png', 'img_01791.png', 'img_01802.png', 'img_01806.png', 'img_01789.png', 'img_01805.png', 'img_01800.png', 'img_01794.png', 'img_01801.png', 'img_01792.png', 'img_01788.png']
Number of plant images for plant corn_spurry : 20
['img_01768.png', 'img_01767.png', 'img_01796.png']
Number of plant images for plant red-root_amaranth : 3
['img_01812.png', 'img_01809.png', 'img_01815.png', 'img_01810.png', 'img_01808.png', 'img_01813.png', 'img_01814.png', 'img_01811.png']
Number of plant images for plant red_fingergrass : 8
['img_00434.png', 'img_00203.png', 'img_00424.png', 'img_00717.png', 'img_00657.png', 'img_00430.png', 'img_00715.png', 'img_00420.png', 'img_00658.png', 'img_00209.png', 'img_00659.png', 'img_00419.png', 'img_00709.png', 'img_00713.png', 'img_00206.png', 'img_00429.png', 'img_00712.png', 'img_00421.png', 'img_

Some weights of SegformerForSemanticSegmentation were not initialized from the model checkpoint at nvidia/mit-b0 and are newly initialized: ['decode_head.batch_norm.bias', 'decode_head.linear_c.2.proj.bias', 'decode_head.batch_norm.num_batches_tracked', 'decode_head.batch_norm.running_var', 'decode_head.classifier.weight', 'decode_head.linear_c.1.proj.bias', 'decode_head.linear_c.2.proj.weight', 'decode_head.classifier.bias', 'decode_head.linear_c.1.proj.weight', 'decode_head.batch_norm.weight', 'decode_head.linear_c.3.proj.bias', 'decode_head.linear_fuse.weight', 'decode_head.linear_c.0.proj.bias', 'decode_head.batch_norm.running_mean', 'decode_head.linear_c.0.proj.weight', 'decode_head.linear_c.3.proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


RuntimeError: To use hyperparameter search, you need to pass your model through a model_init function.

In [None]:
# import subprocess
# from typing import NoReturn

# def shutdown_windows() -> NoReturn:
#     subprocess.run(["shutdown", "/s", "/t", "0"])

# shutdown_windows()