## Import libraries, methods and constants

In [1]:
import numpy as np
from PIL import Image
import datasets
import json
import torch
import codecs
import os
from os import sys

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
module_path = os.path.abspath(os.path.join('../src'))
if module_path not in sys.path:
    sys.path.append(module_path)
from data_prepossessing import create_datasets_for_plants, get_labels
from data_visualization import visualize_annotation, visualize_annotation_for_image
from constants import crop_indices, weed_indices

In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # use GPU if available, otherwise use a CPU
print(device)

cuda


In [4]:
print("Crop indicies: ", crop_indices)
print("Weed indicies: ", weed_indices)

Crop indicies:  [2, 5, 6, 11, 14, 15, 18]
Weed indicies:  [3, 4, 7, 8, 9, 10, 12, 13, 16, 17]


In [5]:
crop = 'broad_bean'
# crop = 'common_buckwheat'
# crop = 'pea'
# crop = 'corn'
# crop = 'soybean'
# crop = 'sunflower'
# crop = 'sugar_beet'

# model_type = 'multiclass'
model_type = 'binary'

model_plant_names = [crop] + weed_plants
print(model_plant_names)

['broad_bean', 'corn_spurry', 'red-root_amaranth', 'red_fingergrass', 'common_wild_oat', 'cornflower', 'corn_cockle', 'milk_thistle', 'rye_brome', 'narrow-leaved_plantain', 'small-flower_geranium']


In [6]:
train_ds, val_ds, test_ds = create_datasets_for_plants(model_plant_names, model_type, crop)

['img_00173.png', 'img_00174.png', 'img_00175.png', 'img_00176.png', 'img_00177.png', 'img_00178.png', 'img_00672.png', 'img_00673.png', 'img_00674.png', 'img_00675.png', 'img_00676.png', 'img_00677.png', 'img_00678.png', 'img_00679.png', 'img_00680.png', 'img_00681.png', 'img_00682.png', 'img_00683.png', 'img_00684.png', 'img_00882.png', 'img_00883.png', 'img_00884.png', 'img_00885.png', 'img_00886.png', 'img_00887.png', 'img_00938.png', 'img_00980.png', 'img_00981.png', 'img_00982.png', 'img_00983.png', 'img_00984.png', 'img_00985.png', 'img_00986.png', 'img_00987.png', 'img_00988.png', 'img_00989.png', 'img_01070.png', 'img_01071.png', 'img_01072.png', 'img_01073.png', 'img_01074.png', 'img_01075.png', 'img_01076.png', 'img_01077.png', 'img_01078.png', 'img_01079.png', 'img_01219.png', 'img_01220.png', 'img_01221.png', 'img_01222.png', 'img_01223.png', 'img_01224.png', 'img_01225.png', 'img_01226.png', 'img_01227.png', 'img_01228.png', 'img_01279.png', 'img_01280.png', 'img_01281.pn

In [7]:
from transformers import SegformerImageProcessor

checkpoint = "nvidia/mit-b0"
image_processor = SegformerImageProcessor.from_pretrained(checkpoint)
image_processor



SegformerImageProcessor {
  "do_normalize": true,
  "do_reduce_labels": false,
  "do_rescale": true,
  "do_resize": true,
  "feature_extractor_type": "SegformerFeatureExtractor",
  "image_mean": [
    0.485,
    0.456,
    0.406
  ],
  "image_processor_type": "SegformerImageProcessor",
  "image_std": [
    0.229,
    0.224,
    0.225
  ],
  "resample": 2,
  "rescale_factor": 0.00392156862745098,
  "size": {
    "height": 512,
    "width": 512
  }
}

In [8]:
id2label, label2id = get_labels(crop)

print('Number of classes:', len(id2label))
print('id2label:', id2label)
print('label2id:', label2id)

labels: ['void', 'soil', 'broad_bean', 'weeds']
ids: [0, 1, 2, 3]
num_labels: 4
id2label: {0: 'void', 1: 'soil', 2: 'broad_bean', 3: 'weeds'}
label2id: {'void': 0, 'soil': 1, 'broad_bean': 2, 'weeds': 3}


In [9]:
from transformers import AutoModelForSemanticSegmentation
checkpoint = 'models/' + model_type + '/' + crop + '/'
model = AutoModelForSemanticSegmentation.from_pretrained(checkpoint, id2label=id2label, label2id=label2id)
model.to(device)

SegformerForSemanticSegmentation(
  (segformer): SegformerModel(
    (encoder): SegformerEncoder(
      (patch_embeddings): ModuleList(
        (0): SegformerOverlapPatchEmbeddings(
          (proj): Conv2d(3, 32, kernel_size=(7, 7), stride=(4, 4), padding=(3, 3))
          (layer_norm): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
        )
        (1): SegformerOverlapPatchEmbeddings(
          (proj): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
          (layer_norm): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
        )
        (2): SegformerOverlapPatchEmbeddings(
          (proj): Conv2d(64, 160, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
          (layer_norm): LayerNorm((160,), eps=1e-05, elementwise_affine=True)
        )
        (3): SegformerOverlapPatchEmbeddings(
          (proj): Conv2d(160, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
          (layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
  

In [10]:
image=test_ds[0]["image"]
encoding = image_processor(image, return_tensors="pt")
pixel_values = encoding.pixel_values.to(device)

In [11]:
outputs = model(pixel_values=pixel_values)

In [12]:
logits = outputs.logits.cpu()

In [13]:
upsampled_logits = torch.nn.functional.interpolate(
    logits,
    size=image.size[::-1],
    mode="bilinear",
    align_corners=False,
)

pred_seg = upsampled_logits.argmax(dim=1)[0]

In [14]:
print("Predicting following unique classes: ", np.unique(pred_seg))
print("The resolution of predicted image is ", pred_seg.shape)
print("which should be equal to the resolution of the original image", image.size[::-1])

Predicting following unique classes:  [1 3]
The resolution of predicted image is  torch.Size([1144, 1600])
which should be equal to the resolution of the original image (1144, 1600)


In [19]:
print(type(pred_seg))
print(type(pred_seg.detach().cpu().numpy()))

<class 'torch.Tensor'>
<class 'numpy.ndarray'>


In [None]:
visualize_annotation_for_image(pred_seg)

In [None]:
visualize_annotation(pred_seg)