In [1]:
import matplotlib.pyplot as plt
from tqdm import tqdm, tqdm_notebook
from PIL import Image
import numpy as np
import pandas as pd
import torch
import os
import matplotlib.pyplot as plt
import seaborn as sns

In [2]:
# Side length (pixels) of the square every image is resized to in preproc().
RESCALE_SIZE = 320
# Dataset split names. Not referenced in the visible cells — presumably used
# by a Dataset class further down; TODO confirm.
DATA_MODES = ['train', 'val', 'test']

In [3]:
# Root of the Kaggle input dataset; every input path below is derived from it
# so the location only has to be changed in one place.
folder_base = '/kaggle/input/2022-ukraine-russia-war-equipment-losses-oryx/'

# Image sub-folders, relative to folder_base (two directories per side).
folder_img_ru = ['img_russia/img_russia/', 'img_russia_2024-07-12/img_russia/']
folder_img_ua = ['img_ukraine/img_ukraine/', 'img_ukraine_2024-07-12/img_ukraine/']

# Per-image metadata tables.
path_img_metadata_ru = os.path.join(folder_base, 'img_russia_losses_metadata.csv')
path_img_metadata_ua = os.path.join(folder_base, 'img_ukraine_losses_metadata.csv')

# Loss tables.
path_losses_ru = os.path.join(folder_base, 'losses_russia.csv')
path_losses_ua = os.path.join(folder_base, 'losses_ukraine.csv')

# Output directory in the Kaggle working area
# (presumably for saved models — TODO confirm against the cells that write to it).
directory_path = '/kaggle/working/model_svs/'

In [4]:
# Column used as the class label: it drives the per-class filtering, the
# validation split and the labels returned by preproc().
target_name = 'equipment'

In [5]:
# Load the image metadata table for the Russian-losses images.
df = pd.read_csv(path_img_metadata_ru)

In [6]:
# Keep only rows whose file name contains "capt" as a standalone token
# (case-insensitive). Rows with a missing file name are excluded (na=False).
capt_mask = df['file'].str.contains(r'\bcapt\b', case=False, na=False)
df_filtered = df[capt_mask]

In [7]:
# Per-row size of the class the row belongs to, aligned with df_filtered's index.
rows_in_class = df_filtered.groupby(target_name)[target_name].transform('count')
# Keep only classes with more than 19 rows.
capt_df = df_filtered[rows_in_class > 19]

In [8]:
# Number of rows per class in the filtered frame.
class_counts = capt_df[target_name].value_counts()
# Validation rows are accumulated here by the split loop; starts empty.
val_data = pd.DataFrame()
# Training pool starts as a full copy; the split loop removes the sampled
# validation rows from it.
train_data = capt_df.copy()

In [10]:
# Display the per-class row counts that the validation split is based on.
class_counts

equipment
Infantry_Fighting_Vehicles                 175
Trucks,_Vehicles,_and_Jeeps                150
Tanks                                      133
Armoured_Fighting_Vehicles                  74
Engineering_Vehicles_And_Equipment          67
Reconnaissance_Unmanned_Aerial_Vehicles     59
Towed_Artillery                             24
Name: count, dtype: int64

In [126]:
# Drop rows whose target column (target_name) contains 'unknown', case-insensitively.
# Rows with a missing value in that column are kept, because na=False makes
# the match False for them.
df_cleaned = df[~df[target_name].str.contains('unknown', case=False, na=False)]

In [127]:
# Renumber the rows 0..n-1 after filtering; the old index is discarded.
df_cleaned = df_cleaned.reset_index(drop=True)

In [240]:
# Hold out a per-class validation subset from train_data.
#
# NOTE: this cell mutates train_data / val_data. Re-running it without first
# re-running the cell that initialises them samples a second time from the
# already reduced training pool.
val_parts = []
for cls, count in class_counts.items():
    if count <= 5:
        # Too few rows to spare any for validation; keep them all for training.
        continue
    if count <= 10:
        n_samples = 2
    elif count <= 20:
        n_samples = 3
    else:
        # The original had separate "<= 100" and "> 100" branches that both
        # took 10 % of the class; they are merged here.
        n_samples = round(count * 0.1)

    # Classes are disjoint, so sampling each one from the not-yet-reduced
    # train_data selects the same rows as dropping after every class would.
    # The fixed seed keeps the split reproducible.
    val_parts.append(
        train_data[train_data[target_name] == cls].sample(n_samples, random_state=42)
    )

if val_parts:
    # Concatenate once instead of growing the frame inside the loop, and skip
    # an empty val_data: concatenating with an empty frame is deprecated in
    # pandas and may change the resulting dtypes.
    new_val_rows = pd.concat(val_parts)
    val_data = new_val_rows if val_data.empty else pd.concat([val_data, new_val_rows])

    # Remove the selected rows from the training set.
    train_data = train_data.drop(new_val_rows.index)

In [241]:
# Report how many rows ended up in each split.
n_val_rows, n_train_rows = len(val_data), len(train_data)
print(f"Validation set size: {n_val_rows}")
print(f"Training set size: {n_train_rows}")

Validation set size: 68
Training set size: 614


In [242]:
from collections import defaultdict
from itertools import combinations

In [243]:
def custom_transform(img, rescale_size):
    """Return ``img`` resized to a ``rescale_size`` x ``rescale_size`` square.

    Bilinear resampling is used. The target is always square, so the aspect
    ratio of the input is not preserved.
    """
    square = (rescale_size,) * 2
    return img.resize(square, resample=Image.BILINEAR)

In [244]:
def preproc(dataframe):
    """Load, RGB-convert and resize every image referenced by ``dataframe``.

    Each row must provide the 'folder', 'equipment' and 'file' columns (used
    to build the image path) and the ``target_name`` column (used as label).
    Rows whose image is missing on disk or fails to load are reported with
    ``print`` and skipped, so the returned lists can be shorter than the frame.

    Returns:
        tuple: ``(image_list, label_list, index_list)`` — three parallel lists
        holding the resized PIL images, their labels and the dataframe index
        of the row each image came from.
    """
    index_list = []
    image_list = []
    label_list = []
    for index, row in tqdm(dataframe.iterrows(), total=dataframe.shape[0]):
        # Rows tagged 'img_russia' live in the first image folder; every other
        # folder value is looked up in the second one.
        img_dir = folder_img_ru[0] if row['folder'] == 'img_russia' else folder_img_ru[1]
        img_path = os.path.join(folder_base, img_dir, row['equipment'], row['file'])

        # Skip rows whose image file is not present on disk.
        if not os.path.exists(img_path):
            print(f"Image file missing for {img_path}, skipping.")
            continue

        try:
            # The context manager closes the underlying file as soon as the
            # pixel data has been copied into the resized image.
            with Image.open(img_path) as source:
                rgb = source if source.mode == 'RGB' else source.convert('RGB')
                resized_img = custom_transform(rgb, RESCALE_SIZE)
        except Exception as e:
            # Best effort: one unreadable image must not abort the whole run.
            print(f"Error processing {img_path}: {e}")
            continue

        # Append only after the image loaded, so the three lists stay aligned.
        image_list.append(resized_img)
        label_list.append(row[target_name])
        index_list.append(index)

    return image_list, label_list, index_list

In [245]:
# Load and resize the training images; train_index records which rows of
# train_data were actually loaded (missing or unreadable images are skipped).
train_images, train_labels, train_index = preproc(train_data)

100%|██████████| 614/614 [00:11<00:00, 51.74it/s]


In [246]:
# Load and resize the validation images; val_index records which rows of
# val_data were actually loaded (missing or unreadable images are skipped).
val_images, val_labels, val_index = preproc(val_data)

100%|██████████| 68/68 [00:01<00:00, 55.89it/s]


In [247]:
from sklearn.preprocessing import LabelEncoder

In [248]:
# Fit the encoder on the distinct class names found in the training labels.
distinct_train_labels = list(set(train_labels))
le = LabelEncoder()
le.fit(distinct_train_labels)

In [249]:
# Encode both label lists with the encoder fitted on the training labels.
# transform() raises if val_labels contains a class the encoder never saw.
enc_tlabels = le.transform(train_labels)
enc_vlabels = le.transform(val_labels)

In [250]:
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader

In [251]:
# Per-channel mean / std passed to transforms.Normalize below.
# Presumably computed from this dataset's training images — TODO confirm the source.
MEAN = [0.5097, 0.524, 0.5099]
STD = [0.212, 0.212, 0.237]

In [252]:
# Augmenting pipeline: convert to tensor, apply the random geometric
# augmentations, then normalise with the per-channel MEAN / STD.
_v1_steps = [
    transforms.ToTensor(),
    transforms.RandomHorizontalFlip(p=0.25),
    transforms.RandomRotation(degrees=25),
    transforms.RandomPerspective(distortion_scale=0.6, p=0.25),
    transforms.Normalize(mean=MEAN, std=STD),
]
transform_v1 = transforms.Compose(_v1_steps)