In [1]:
import numpy as np
import pandas as pd
from tqdm import tqdm
from transformers import MaskFormerFeatureExtractor, MaskFormerForInstanceSegmentation
import PIL
from PIL import Image
import requests
from transformers import pipeline
import datasets
from datasets import load_dataset
import os
import evaluate
import torch
import cv2
import json
import codecs

  from .autonotebook import tqdm as notebook_tqdm


## Split WE3DS dataset per crop type

In [2]:
im = cv2.imread('./WE3DS/annotations/segmentation/SegmentationLabel/img_00002.png')
np.unique(im)

array([ 0,  1, 11], dtype=uint8)

In [3]:
image_folder = './WE3DS/images/'
annotation_folder = './WE3DS/annotations/segmentation/SegmentationLabel/'
# Define the paths to the images and annotations
image_paths = np.array(os.listdir(image_folder))
annotation_paths = np.array(os.listdir(annotation_folder))

In [4]:
plant_classification = {
    'void': 'void',
    'soil': 'soil',
    'broad bean': 'crop',
    'corn spurry': 'weed',
    'red-root amaranth': 'weed',
    'common buckwheat': 'crop',
    'pea': 'crop',
    'red fingergrass': 'weed',
    'common wild oat': 'weed',
    'cornflower': 'weed',
    'corn cockle': 'weed',
    'corn': 'crop',
    'milk thistle': 'weed',
    'rye brome': 'weed',
    'soybean': 'crop',
    'sunflower': 'crop',
    'narrow-leaved plantain': 'weed',
    'small-flower geranium': 'weed',
    'sugar beet': 'crop'
}

crop_indices = [index for index, value in enumerate(plant_classification) if plant_classification[value] == 'crop']
weed_indices = [index for index, value in enumerate(plant_classification) if plant_classification[value] == 'weed']

print(crop_indices)
print(weed_indices)

[2, 5, 6, 11, 14, 15, 18]
[3, 4, 7, 8, 9, 10, 12, 13, 16, 17]


In [5]:
output_file_path = './WE3DS_class_presence.json'
we3ds_class_presence = np.array([])

if not os.path.isfile(output_file_path):
    for image_path, annotation_path in zip(image_paths, annotation_paths):
        im = cv2.imread(annotation_folder + annotation_path)
        unique_image_classes = np.unique(im)   
        image_class_pressence = {
            'image_path': image_path,
            'class_presence': np.zeros(len(plant_classification)).tolist()
        }

        for unique_image_class in unique_image_classes:
            image_class_pressence['class_presence'][unique_image_class] = 1

        we3ds_class_presence = np.append(we3ds_class_presence, image_class_pressence)
    
    with open(output_file_path, 'w') as file:
        we3ds_class_presence_as_list = we3ds_class_presence.tolist()
        json.dump(we3ds_class_presence_as_list, file)

else:
    we3ds_class_presence = json.load(codecs.open(output_file_path, 'r', 'utf-8-sig'))

In [6]:
we3ds_class_presence

[{'image_path': 'img_00000.png',
  'class_presence': [1,
   1,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   1,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0]},
 {'image_path': 'img_00001.png',
  'class_presence': [0.0,
   1,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   1,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0]},
 {'image_path': 'img_00002.png',
  'class_presence': [1,
   1,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   1,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0]},
 {'image_path': 'img_00003.png',
  'class_presence': [0.0,
   1,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   1,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0]},
 {'image_path': 'img_00004.png',
  'class_presence': [0.0,
   1,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   1,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0]},

In [7]:
list(plant_classification.keys())

['void',
 'soil',
 'broad bean',
 'corn spurry',
 'red-root amaranth',
 'common buckwheat',
 'pea',
 'red fingergrass',
 'common wild oat',
 'cornflower',
 'corn cockle',
 'corn',
 'milk thistle',
 'rye brome',
 'soybean',
 'sunflower',
 'narrow-leaved plantain',
 'small-flower geranium',
 'sugar beet']

In [8]:
all_crop_images = []
for crop_index in crop_indices:
    crop_images = [entry['image_path'] for entry in we3ds_class_presence if entry['class_presence'][crop_index] == 1]
    print(list(plant_classification.keys())[crop_index], len(crop_images))
    all_crop_images = np.append(all_crop_images, crop_images)

print("Number of images that contain at least one crop: ", len(all_crop_images))
print("Number of images that contain only one crop: ", len(np.unique(all_crop_images)))

broad bean 210
common buckwheat 137
pea 207
corn 403
soybean 303
sunflower 135
sugar beet 410
Number of images that contain at least one crop:  1805
Number of images that contain only one crop:  1803


In [9]:
unique_images, counts = np.unique(all_crop_images, return_counts=True)
duplicate_images = unique_images[counts > 1]
print("Duplicate images: ", duplicate_images)
for image in duplicate_images:
    im = cv2.imread(annotation_folder + image)
    print(np.unique(im))

Duplicate images:  ['img_01096.png' 'img_01098.png']
[ 0  1  6 15]


[ 0  1  6 15]


In [10]:
image_paths_set = set(image_paths)
image_count_set = set(all_crop_images)

# Find the difference between the sets
images_with_no_crops = image_paths_set - image_count_set

# Convert the result back to a list
images_with_no_crops_as_list = list(images_with_no_crops)

print(len(image_paths_set))
print(len(image_count_set))
print(len(images_with_no_crops_as_list))
print(len(image_count_set) + len(images_with_no_crops_as_list))

2568
1803
765
2568


In [11]:
all_weed_images = []
for weed_index in weed_indices:
    weed_images = [entry['image_path'] for entry in we3ds_class_presence if entry['class_presence'][weed_index] == 1]
    print(list(plant_classification.keys())[weed_index], len(weed_images))
    all_weed_images = np.append(all_weed_images, weed_images)

corn spurry 20
red-root amaranth 34
red fingergrass 8
common wild oat 44
cornflower 286
corn cockle 277
milk thistle 226
rye brome 60
narrow-leaved plantain 22
small-flower geranium 156


In [13]:
for crop_index in crop_indices:
    crop_images = [entry['image_path'] for entry in we3ds_class_presence if entry['class_presence'][crop_index] == 1]
    print(list(plant_classification.keys())[crop_index], len(crop_images))

    for weed_index in weed_indices:
        weed_images = [entry['image_path'] for entry in we3ds_class_presence if entry['class_presence'][weed_index] == 1]
        weed_in_crop_images = np.intersect1d(crop_images, weed_images)
        print("   ", list(plant_classification.keys())[weed_index], len(weed_in_crop_images))

print("Images with no crops ", len(images_with_no_crops_as_list))
for weed_index in weed_indices:
    weed_images = [entry['image_path'] for entry in we3ds_class_presence if entry['class_presence'][weed_index] == 1]
    weed_in_crop_images = np.intersect1d(images_with_no_crops_as_list, weed_images)
    print("   ", list(plant_classification.keys())[weed_index], len(weed_in_crop_images))


broad bean 210
    corn spurry 0
    red-root amaranth 0
    red fingergrass 0
    common wild oat 0
    cornflower 0
    corn cockle 0
    milk thistle 0
    rye brome 0
    narrow-leaved plantain 0
    small-flower geranium 0
common buckwheat 137
    corn spurry 0
    red-root amaranth 0
    red fingergrass 0
    common wild oat 0
    cornflower 0
    corn cockle 0
    milk thistle 2
    rye brome 0
    narrow-leaved plantain 0
    small-flower geranium 0
pea 207
    corn spurry 0
    red-root amaranth 0
    red fingergrass 0
    common wild oat 0
    cornflower 0
    corn cockle 0
    milk thistle 1
    rye brome 0
    narrow-leaved plantain 0
    small-flower geranium 0
corn 403
    corn spurry 0
    red-root amaranth 18
    red fingergrass 0
    common wild oat 0
    cornflower 58
    corn cockle 0
    milk thistle 0
    rye brome 0
    narrow-leaved plantain 0
    small-flower geranium 49
soybean 303
    corn spurry 0
    red-root amaranth 5
    red fingergrass 0
    common wild 

In [15]:
rye_brome_images = [entry['image_path'] for entry in we3ds_class_presence if entry['class_presence'][13] == 1]
print(rye_brome_images)
print(len(rye_brome_images))

['img_00937.png', 'img_00960.png', 'img_00961.png', 'img_00962.png', 'img_00963.png', 'img_00964.png', 'img_00965.png', 'img_00966.png', 'img_00967.png', 'img_00968.png', 'img_00969.png', 'img_01050.png', 'img_01051.png', 'img_01052.png', 'img_01053.png', 'img_01054.png', 'img_01055.png', 'img_01056.png', 'img_01057.png', 'img_01058.png', 'img_01059.png', 'img_01149.png', 'img_01150.png', 'img_01151.png', 'img_01152.png', 'img_01153.png', 'img_01154.png', 'img_01155.png', 'img_01156.png', 'img_01157.png', 'img_01158.png', 'img_01199.png', 'img_01200.png', 'img_01201.png', 'img_01202.png', 'img_01203.png', 'img_01204.png', 'img_01205.png', 'img_01206.png', 'img_01207.png', 'img_01208.png', 'img_01269.png', 'img_01270.png', 'img_01271.png', 'img_01272.png', 'img_01273.png', 'img_01274.png', 'img_01275.png', 'img_01276.png', 'img_01277.png', 'img_01278.png', 'img_01769.png', 'img_01770.png', 'img_01771.png', 'img_01772.png', 'img_01773.png', 'img_01774.png', 'img_01775.png', 'img_01776.pn

## Prepare images as input for transformers model

In [None]:
# Create an empty list to store the dataset
images_list = []

# Iterate over the image and annotation paths
for image_path, annotation_path in zip(image_paths, annotation_paths):
    # Load the image and annotation using PIL
    image = Image.open(image_folder + image_path)
    annotation = Image.open(annotation_folder + annotation_path)
    
    # Create a dictionary entry for the dataseta
    entry = {'image': image, 'annotation': annotation}
    
    # Add the entry to the dataset
    images_list.append(entry)

In [None]:
filename = images_list[0]['image']

In [None]:
# ds = load_dataset("scene_parse_150", split="train[:50]")
dataset = datasets.Dataset.from_list(images_list)

In [None]:
dataset = dataset.train_test_split(test_size=0.2)
train_ds = dataset["train"]
test_ds = dataset["test"]

In [None]:
train_ds[12]