In [1]:
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
os.environ['TORCH_USE_CUDA_DSA'] = "1"

In [2]:
%env CUDA_LAUNCH_BLOCKING=1

env: CUDA_LAUNCH_BLOCKING=1


In [3]:
import torch
# setting device on GPU if available, else CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)
print()

#Additional Info when using cuda
if device.type == 'cuda':
    print(torch.cuda.get_device_name(0))
    print('CUDA version:', torch.version.cuda)
    print('Memory Usage:')
    print('Allocated:', round(torch.cuda.memory_allocated(0)/1024**3,1), 'GB')
    print('Cached:   ', round(torch.cuda.memory_reserved(0)/1024**3,1), 'GB')

Using device: cuda

NVIDIA GeForce GTX 960
CUDA version: 11.7
Memory Usage:
Allocated: 0.0 GB
Cached:    0.0 GB


In [4]:
import numpy as np
import PIL
from PIL import Image
import datasets
import os
import evaluate
import torch
import json
import codecs

  from .autonotebook import tqdm as notebook_tqdm


In [5]:
seed = 9876

image_folder = './WE3DS/images/'
annotation_folder = './WE3DS/annotations/'
annotations_aggregated_folder = './WE3DS/annotations_aggregated/'
# Define the paths to the images and annotations
all_image_names = np.array(os.listdir(image_folder))

In [6]:
plant_classification = {
    'void': 'void',
    'soil': 'soil',
    'broad bean': 'crop',
    'corn spurry': 'weed',
    'red-root amaranth': 'weed',
    'common buckwheat': 'crop',
    'pea': 'crop',
    'red fingergrass': 'weed',
    'common wild oat': 'weed',
    'cornflower': 'weed',
    'corn cockle': 'weed',
    'corn': 'crop',
    'milk thistle': 'weed',
    'rye brome': 'weed',
    'soybean': 'crop',
    'sunflower': 'crop',
    'narrow-leaved plantain': 'weed',
    'small-flower geranium': 'weed',
    'sugar beet': 'crop'
}

crop_indices = [index for index, value in enumerate(plant_classification) if plant_classification[value] == 'crop']
weed_indices = [index for index, value in enumerate(plant_classification) if plant_classification[value] == 'weed']

print("Crop indicies: ", crop_indices)
print("Weed indicies: ", weed_indices)

Crop indicies:  [2, 5, 6, 11, 14, 15, 18]
Weed indicies:  [3, 4, 7, 8, 9, 10, 12, 13, 16, 17]


In [7]:
# This is a configuration that will determine which type of model and for which plants the following code will be executed
# weed_plants = [plant_name.replace(" ", "_") for plant_name, classification in plant_classification.items() if classification == 'weed']
weed_plants = []

crop = 'broad_bean'
# crop = 'common_buckwheat'
# crop = 'pea'
# crop = 'corn'
# crop = 'soybean'
# crop = 'sunflower'
# crop = 'sugar_beet'

model_type = 'multiclass'
# model_type = 'binary'

model_plant_names = [crop] + weed_plants
print(model_plant_names)

['broad_bean']


In [8]:
def get_image_meta_filepath(plant_name):
    suffix = '_images.json'
    if plant_name in weed_plants:
        suffix = '_no_crop_images.json'
    return './meta/' + plant_name + suffix

In [9]:
def get_image_list_for_plant(plant_name, model_type):
    # Create an empty list to store the dataset
    image_list = []
    plant_image_names = json.load(codecs.open(get_image_meta_filepath(plant_name), 'r', 'utf-8-sig'))

    # Exclude images that contain more than one crop
    image_names_to_exclude = ['img_01096.png' 'img_01098.png']
    plant_image_names = [image_name for image_name in plant_image_names if image_name not in image_names_to_exclude]
    print(plant_image_names)

    # Iterate over the image and annotation paths
    for image_name in plant_image_names:
        # Load the image and annotation using PIL
        image = Image.open(image_folder + image_name)
        path = None
        if model_type == 'multiclass':
            path = 'WE3DS/annotations_multiclass/' + crop + '/' + image_name
        elif model_type == 'binary':
            path = 'WE3DS/annotations_binary/' + crop + '/' + image_name

        annotation = Image.open(path)
        
        # Create a dictionary entry for the dataseta
        entry = {'image': image, 'annotation': annotation}
        
        # Add the entry to the dataset
        image_list.append(entry)

    return image_list

In [10]:
def create_and_split_dataset_for_plant(plant_image_list):
    dataset = datasets.Dataset.from_list(plant_image_list)
    dataset = dataset.train_test_split(test_size=0.5, seed=seed)
    train_ds = dataset["train"]
    val_ds, test_ds = dataset["test"].train_test_split(test_size=0.5, seed=seed).values()
    return train_ds, val_ds, test_ds

In [11]:
def create_datasets_for_plants(plant_names, model_type):
    p0_image_list = get_image_list_for_plant(plant_names[0], model_type)
    print("Number of plant images for plant", plant_names[0], ":", len(p0_image_list))
    train_ds, val_ds, test_ds = create_and_split_dataset_for_plant(p0_image_list)

    for plant_name in plant_names[1:]:
        p_image_list = get_image_list_for_plant(plant_name, model_type)
        print("Number of plant images for plant", plant_name, ":", len(p_image_list))
        p_train_ds, p_val_ds, p_test_ds = create_and_split_dataset_for_plant(p_image_list)

        train_ds = datasets.concatenate_datasets([train_ds, p_train_ds])
        val_ds = datasets.concatenate_datasets([val_ds, p_val_ds])
        test_ds = datasets.concatenate_datasets([test_ds, p_test_ds])

    return train_ds, val_ds, test_ds

In [12]:
train_ds, val_ds, test_ds = create_datasets_for_plants(model_plant_names, model_type)

['img_00173.png', 'img_00174.png', 'img_00175.png', 'img_00176.png', 'img_00177.png', 'img_00178.png', 'img_00672.png', 'img_00673.png', 'img_00674.png', 'img_00675.png', 'img_00676.png', 'img_00677.png', 'img_00678.png', 'img_00679.png', 'img_00680.png', 'img_00681.png', 'img_00682.png', 'img_00683.png', 'img_00684.png', 'img_00882.png', 'img_00883.png', 'img_00884.png', 'img_00885.png', 'img_00886.png', 'img_00887.png', 'img_00938.png', 'img_00980.png', 'img_00981.png', 'img_00982.png', 'img_00983.png', 'img_00984.png', 'img_00985.png', 'img_00986.png', 'img_00987.png', 'img_00988.png', 'img_00989.png', 'img_01070.png', 'img_01071.png', 'img_01072.png', 'img_01073.png', 'img_01074.png', 'img_01075.png', 'img_01076.png', 'img_01077.png', 'img_01078.png', 'img_01079.png', 'img_01219.png', 'img_01220.png', 'img_01221.png', 'img_01222.png', 'img_01223.png', 'img_01224.png', 'img_01225.png', 'img_01226.png', 'img_01227.png', 'img_01228.png', 'img_01279.png', 'img_01280.png', 'img_01281.pn

In [13]:
print("Training subset number of images: " + str(train_ds.num_rows))
print("Validation subset number of images: " + str(val_ds.num_rows))
print("Test subset number of images: " + str(test_ds.num_rows))

Training subset number of images: 105
Validation subset number of images: 52
Test subset number of images: 53


In [14]:
from transformers import AutoImageProcessor

checkpoint = "nvidia/mit-b0"
image_processor = AutoImageProcessor.from_pretrained(checkpoint)

Could not find image processor class in the image processor config or the model config. Loading based on pattern matching with the model's feature extractor configuration.


In [15]:
def train_transforms(example_batch):
    images = [x for x in example_batch["image"]]
    labels = [x for x in example_batch["annotation"]]
    inputs = image_processor(images, labels)
    return inputs

In [16]:
metric = evaluate.load("mean_iou")

In [17]:
labels = ['void', 'soil', crop]

if model_type == 'binary':
    labels.append('weeds')
elif model_type == 'multiclass':
    for weed_plant in weed_plants:
        labels.append(weed_plant)

ids = list(range(0, len(labels)))

id2label = dict(zip(ids, labels))
label2id = dict(zip(labels, ids))

num_labels = len(labels)

print('labels:', labels)
print('ids:', ids)
print('num_labels:', num_labels)
print('id2label:', id2label)
print('label2id:', label2id)

labels: ['void', 'soil', 'broad_bean']
ids: [0, 1, 2]
num_labels: 3
id2label: {0: 'void', 1: 'soil', 2: 'broad_bean'}
label2id: {'void': 0, 'soil': 1, 'broad_bean': 2}


In [18]:
def compute_metrics(eval_pred):
    with torch.no_grad():
        logits, labels = eval_pred
        logits_tensor = torch.from_numpy(logits)
        logits_tensor = torch.nn.functional.interpolate(
            logits_tensor,
            size=labels.shape[-2:],
            mode="bilinear",
            align_corners=False,
        ).argmax(dim=1)

        pred_labels = logits_tensor.detach().cpu().numpy()
        metrics = metric.compute(
            predictions=pred_labels,
            references=labels,
            num_labels=num_labels,
            ignore_index=255,
            reduce_labels=False,
        )
        for key, value in metrics.items():
            if type(value) is np.ndarray:
                metrics[key] = value.tolist()
        return metrics

In [19]:
from transformers import AutoModelForSemanticSegmentation, TrainingArguments, Trainer, EarlyStoppingCallback

model = AutoModelForSemanticSegmentation.from_pretrained(checkpoint, id2label=id2label, label2id=label2id)
model

Some weights of SegformerForSemanticSegmentation were not initialized from the model checkpoint at nvidia/mit-b0 and are newly initialized: ['decode_head.linear_c.1.proj.weight', 'decode_head.linear_c.3.proj.bias', 'decode_head.linear_c.0.proj.bias', 'decode_head.batch_norm.weight', 'decode_head.batch_norm.running_var', 'decode_head.linear_c.2.proj.weight', 'decode_head.linear_c.0.proj.weight', 'decode_head.linear_c.1.proj.bias', 'decode_head.linear_c.2.proj.bias', 'decode_head.linear_c.3.proj.weight', 'decode_head.batch_norm.num_batches_tracked', 'decode_head.classifier.bias', 'decode_head.linear_fuse.weight', 'decode_head.batch_norm.bias', 'decode_head.batch_norm.running_mean', 'decode_head.classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


SegformerForSemanticSegmentation(
  (segformer): SegformerModel(
    (encoder): SegformerEncoder(
      (patch_embeddings): ModuleList(
        (0): SegformerOverlapPatchEmbeddings(
          (proj): Conv2d(3, 32, kernel_size=(7, 7), stride=(4, 4), padding=(3, 3))
          (layer_norm): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
        )
        (1): SegformerOverlapPatchEmbeddings(
          (proj): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
          (layer_norm): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
        )
        (2): SegformerOverlapPatchEmbeddings(
          (proj): Conv2d(64, 160, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
          (layer_norm): LayerNorm((160,), eps=1e-05, elementwise_affine=True)
        )
        (3): SegformerOverlapPatchEmbeddings(
          (proj): Conv2d(160, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
          (layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
  

In [20]:
# TODO try to remove eval_accumulation_steps
training_args = TrainingArguments(
    output_dir="segformer-b0-scene-parse-150",
    learning_rate=6e-5,
    num_train_epochs=100,
    per_device_train_batch_size=5,
    per_device_eval_batch_size=5,
    save_total_limit=3,
    evaluation_strategy="steps",
    save_strategy="steps",
    save_steps=20,
    eval_steps=20,
    logging_steps=1,
    eval_accumulation_steps=5,
    remove_unused_columns=False,
    load_best_model_at_end=True,
    seed=seed,
)

In [21]:
def initialize_trainer(train_ds, test_ds) :
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_ds,
        eval_dataset=test_ds,
        compute_metrics=compute_metrics,
        callbacks=[EarlyStoppingCallback(early_stopping_patience=3)]
    )
    return trainer

In [22]:
train_ds.set_transform(train_transforms)
val_ds.set_transform(train_transforms)
test_ds.set_transform(train_transforms)

In [None]:
trainer = initialize_trainer(train_ds, val_ds)
trainer.train()

  0%|          | 1/2100 [00:06<3:57:42,  6.79s/it]

{'loss': 1.2701, 'learning_rate': 5.997142857142857e-05, 'epoch': 0.05}


  0%|          | 2/2100 [00:10<2:52:56,  4.95s/it]

{'loss': 1.2204, 'learning_rate': 5.994285714285715e-05, 'epoch': 0.1}


  0%|          | 3/2100 [00:13<2:26:40,  4.20s/it]

{'loss': 1.1782, 'learning_rate': 5.9914285714285716e-05, 'epoch': 0.14}


  0%|          | 4/2100 [00:17<2:13:59,  3.84s/it]

{'loss': 1.1475, 'learning_rate': 5.988571428571429e-05, 'epoch': 0.19}


  0%|          | 5/2100 [00:20<2:06:58,  3.64s/it]

{'loss': 1.1325, 'learning_rate': 5.9857142857142856e-05, 'epoch': 0.24}


  0%|          | 6/2100 [00:23<2:03:10,  3.53s/it]

{'loss': 1.1781, 'learning_rate': 5.982857142857143e-05, 'epoch': 0.29}


  0%|          | 7/2100 [00:26<2:01:01,  3.47s/it]

{'loss': 1.1076, 'learning_rate': 5.9800000000000003e-05, 'epoch': 0.33}


  0%|          | 8/2100 [00:30<1:59:39,  3.43s/it]

{'loss': 1.1212, 'learning_rate': 5.977142857142857e-05, 'epoch': 0.38}


  0%|          | 9/2100 [00:33<1:57:37,  3.38s/it]

{'loss': 1.0625, 'learning_rate': 5.9742857142857144e-05, 'epoch': 0.43}


  0%|          | 10/2100 [00:36<1:56:24,  3.34s/it]

{'loss': 1.101, 'learning_rate': 5.971428571428572e-05, 'epoch': 0.48}


  1%|          | 11/2100 [00:40<1:55:35,  3.32s/it]

{'loss': 1.114, 'learning_rate': 5.9685714285714284e-05, 'epoch': 0.52}


  1%|          | 12/2100 [00:43<1:55:10,  3.31s/it]

{'loss': 1.0103, 'learning_rate': 5.965714285714286e-05, 'epoch': 0.57}


  1%|          | 13/2100 [00:46<1:54:52,  3.30s/it]

{'loss': 0.997, 'learning_rate': 5.962857142857143e-05, 'epoch': 0.62}


  1%|          | 14/2100 [00:50<1:55:20,  3.32s/it]

{'loss': 0.9339, 'learning_rate': 5.96e-05, 'epoch': 0.67}


  1%|          | 15/2100 [00:53<1:54:53,  3.31s/it]

{'loss': 0.9253, 'learning_rate': 5.957142857142857e-05, 'epoch': 0.71}


  1%|          | 16/2100 [00:56<1:54:22,  3.29s/it]

{'loss': 0.9021, 'learning_rate': 5.9542857142857146e-05, 'epoch': 0.76}


  1%|          | 17/2100 [00:59<1:54:37,  3.30s/it]

{'loss': 0.9424, 'learning_rate': 5.951428571428572e-05, 'epoch': 0.81}


  1%|          | 18/2100 [01:03<1:54:37,  3.30s/it]

{'loss': 0.8931, 'learning_rate': 5.9485714285714286e-05, 'epoch': 0.86}


  1%|          | 19/2100 [01:06<1:54:48,  3.31s/it]

{'loss': 0.8871, 'learning_rate': 5.945714285714285e-05, 'epoch': 0.9}


  1%|          | 20/2100 [01:09<1:54:37,  3.31s/it]

{'loss': 0.8524, 'learning_rate': 5.9428571428571434e-05, 'epoch': 0.95}


                                                   
  1%|          | 20/2100 [02:19<1:54:37,  3.31s/it]

{'eval_loss': 1.0772571563720703, 'eval_mean_iou': 0.4197518151208867, 'eval_mean_accuracy': 0.6494130605636658, 'eval_overall_accuracy': 0.948817619910607, 'eval_per_category_iou': [3.407096983015622e-05, 0.947718486336521, 0.31150288805630905], 'eval_per_category_accuracy': [0.0012461059190031153, 0.9478396110241409, 0.9991534647478534], 'eval_runtime': 69.4116, 'eval_samples_per_second': 0.749, 'eval_steps_per_second': 0.158, 'epoch': 0.95}


  1%|          | 21/2100 [02:21<13:50:21, 23.96s/it]

{'loss': 0.8507, 'learning_rate': 5.94e-05, 'epoch': 1.0}


  1%|          | 22/2100 [02:25<10:17:37, 17.83s/it]

{'loss': 0.812, 'learning_rate': 5.937142857142857e-05, 'epoch': 1.05}


  1%|          | 23/2100 [02:28<7:41:39, 13.34s/it] 

{'loss': 0.838, 'learning_rate': 5.934285714285715e-05, 'epoch': 1.1}


  1%|          | 24/2100 [02:31<5:52:53, 10.20s/it]

{'loss': 0.8071, 'learning_rate': 5.9314285714285715e-05, 'epoch': 1.14}


  1%|          | 25/2100 [02:34<4:36:53,  8.01s/it]

{'loss': 0.8197, 'learning_rate': 5.928571428571429e-05, 'epoch': 1.19}


  1%|          | 26/2100 [02:37<3:43:50,  6.48s/it]

{'loss': 0.7535, 'learning_rate': 5.925714285714286e-05, 'epoch': 1.24}


  1%|▏         | 27/2100 [02:39<3:07:02,  5.41s/it]

{'loss': 0.7523, 'learning_rate': 5.922857142857143e-05, 'epoch': 1.29}


  1%|▏         | 28/2100 [02:42<2:40:53,  4.66s/it]

{'loss': 0.9478, 'learning_rate': 5.92e-05, 'epoch': 1.33}


  1%|▏         | 29/2100 [02:45<2:22:17,  4.12s/it]

{'loss': 0.7535, 'learning_rate': 5.917142857142857e-05, 'epoch': 1.38}


  1%|▏         | 30/2100 [02:48<2:09:34,  3.76s/it]

{'loss': 0.7151, 'learning_rate': 5.914285714285715e-05, 'epoch': 1.43}


  1%|▏         | 31/2100 [02:51<2:00:29,  3.49s/it]

{'loss': 0.6585, 'learning_rate': 5.9114285714285717e-05, 'epoch': 1.48}


  2%|▏         | 32/2100 [02:54<1:54:04,  3.31s/it]

{'loss': 0.7416, 'learning_rate': 5.9085714285714283e-05, 'epoch': 1.52}


  2%|▏         | 33/2100 [02:57<1:49:46,  3.19s/it]

{'loss': 0.7008, 'learning_rate': 5.9057142857142864e-05, 'epoch': 1.57}


  2%|▏         | 34/2100 [03:00<1:47:01,  3.11s/it]

{'loss': 0.6562, 'learning_rate': 5.902857142857143e-05, 'epoch': 1.62}


  2%|▏         | 35/2100 [03:03<1:44:32,  3.04s/it]

{'loss': 0.6443, 'learning_rate': 5.9e-05, 'epoch': 1.67}


  2%|▏         | 36/2100 [03:05<1:42:37,  2.98s/it]

{'loss': 0.6744, 'learning_rate': 5.897142857142857e-05, 'epoch': 1.71}


  2%|▏         | 37/2100 [03:08<1:41:35,  2.95s/it]

{'loss': 0.698, 'learning_rate': 5.8942857142857145e-05, 'epoch': 1.76}


  2%|▏         | 38/2100 [03:11<1:41:11,  2.94s/it]

{'loss': 0.6851, 'learning_rate': 5.891428571428572e-05, 'epoch': 1.81}


  2%|▏         | 39/2100 [03:14<1:43:30,  3.01s/it]

{'loss': 0.7036, 'learning_rate': 5.8885714285714285e-05, 'epoch': 1.86}


In [None]:
trainer.state.log_history

[{'loss': 0.6935,
  'learning_rate': 5.997142857142857e-05,
  'epoch': 0.05,
  'step': 1},
 {'loss': 0.676,
  'learning_rate': 5.994285714285715e-05,
  'epoch': 0.1,
  'step': 2},
 {'loss': 0.6713,
  'learning_rate': 5.9914285714285716e-05,
  'epoch': 0.14,
  'step': 3},
 {'loss': 0.6325,
  'learning_rate': 5.988571428571429e-05,
  'epoch': 0.19,
  'step': 4},
 {'loss': 0.6404,
  'learning_rate': 5.9857142857142856e-05,
  'epoch': 0.24,
  'step': 5},
 {'loss': 0.6362,
  'learning_rate': 5.982857142857143e-05,
  'epoch': 0.29,
  'step': 6},
 {'loss': 0.5927,
  'learning_rate': 5.9800000000000003e-05,
  'epoch': 0.33,
  'step': 7},
 {'loss': 0.5819,
  'learning_rate': 5.977142857142857e-05,
  'epoch': 0.38,
  'step': 8},
 {'loss': 0.6365,
  'learning_rate': 5.9742857142857144e-05,
  'epoch': 0.43,
  'step': 9},
 {'loss': 0.5707,
  'learning_rate': 5.971428571428572e-05,
  'epoch': 0.48,
  'step': 10},
 {'loss': 0.5955,
  'learning_rate': 5.9685714285714284e-05,
  'epoch': 0.52,
  'step':

In [None]:
# Save the trained model, so that it can be used for inference later.
# Save the log history, so that it can be used for plotting later.
trainer.save_model('models/' + model_type + '/' + crop)
with open('models/' + model_type + '/' + crop + '/log_history.json', 'w') as file:
    log_history = trainer.state.log_history
    json.dump(log_history, file)

In [None]:
outputs = trainer.predict(test_ds)

100%|██████████| 11/11 [01:13<00:00,  6.72s/it]


In [None]:
test_metric = trainer.evaluate(test_ds)

100%|██████████| 11/11 [01:18<00:00,  7.13s/it]


In [None]:
test_metric

{'eval_loss': 0.027685143053531647,
 'eval_mean_iou': 0.9514935731094303,
 'eval_mean_accuracy': 0.9739695790575061,
 'eval_overall_accuracy': 0.997216919396627,
 'eval_per_category_iou': [0.9971403494393721, 0.9058467967794884],
 'eval_per_category_accuracy': [0.99860670276049, 0.9493324553545223],
 'eval_runtime': 80.4715,
 'eval_samples_per_second': 0.659,
 'eval_steps_per_second': 0.137,
 'epoch': 20.0}

In [None]:
image = Image.open('./WE3DS/annotations/segmentation/SegmentationLabel/img_00000.png')
image.size[::-1]

(1144, 1600)

In [None]:
# upsampled_ouputs = torch.nn.functional.interpolate(
#     outputs,
#     size=image.size[::-1],
#     mode="bilinear",
#     align_corners=False,
# )
# y_pred = upsampled_ouputs.predictions.argmax(1)

AttributeError: 'PredictionOutput' object has no attribute 'dim'

In [None]:
print("Predicting following unique classes: ", np.unique(y_pred))
print(y_pred.shape)

Predicting following unique classes:  [0 1]
(53, 128, 128)


In [None]:
y_pred

array([[[0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        ...,
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0]],

       [[0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        ...,
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0]],

       [[0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        ...,
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0]],

       ...,

       [[0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        ...,
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0]],

       [[0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        ...,
        [0, 0, 0, ..., 

In [None]:
trainer.optimizer

AcceleratedOptimizer (
Parameter Group 0
    betas: (0.9, 0.999)
    correct_bias: True
    eps: 1e-08
    initial_lr: 6e-05
    lr: 4.8e-05
    weight_decay: 0.0

Parameter Group 1
    betas: (0.9, 0.999)
    correct_bias: True
    eps: 1e-08
    initial_lr: 6e-05
    lr: 4.8e-05
    weight_decay: 0.0
)