In [2]:
import torch
from datasets import Dataset

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
ds = Dataset.load_from_disk("dataset/test_dataset_for_vit")

In [4]:
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("image-classification", model="vit-base-beans", device=0, framework="pt")

In [5]:
result = pipe(ds['image'])

In [6]:
ds

Dataset({
    features: ['image', 'image_file_path'],
    num_rows: 6925
})

In [7]:
result_dict = {}
for i, filename in enumerate(ds['image_file_path']):
    result_dict[filename] = result[i][0]['label']

In [9]:
result_dict

{'w_0.jpg': 'whale_38681',
 'w_10.jpg': 'whale_95370',
 'w_10000.jpg': 'whale_38681',
 'w_10001.jpg': 'whale_38681',
 'w_10002.jpg': 'whale_38681',
 'w_10003.jpg': 'whale_35004',
 'w_10004.jpg': 'whale_38681',
 'w_10005.jpg': 'whale_38681',
 'w_10006.jpg': 'whale_38681',
 'w_10007.jpg': 'whale_38681',
 'w_10008.jpg': 'whale_38681',
 'w_10009.jpg': 'whale_95370',
 'w_1001.jpg': 'whale_38681',
 'w_10010.jpg': 'whale_95370',
 'w_10011.jpg': 'whale_38681',
 'w_10012.jpg': 'whale_38681',
 'w_10013.jpg': 'whale_38681',
 'w_10014.jpg': 'whale_35004',
 'w_10015.jpg': 'whale_38681',
 'w_10016.jpg': 'whale_23525',
 'w_10017.jpg': 'whale_38681',
 'w_10018.jpg': 'whale_23525',
 'w_10019.jpg': 'whale_38681',
 'w_1002.jpg': 'whale_95370',
 'w_10020.jpg': 'whale_38681',
 'w_10021.jpg': 'whale_38681',
 'w_10022.jpg': 'whale_38681',
 'w_10023.jpg': 'whale_38681',
 'w_10024.jpg': 'whale_38681',
 'w_10025.jpg': 'whale_95370',
 'w_10026.jpg': 'whale_38681',
 'w_10027.jpg': 'whale_38681',
 'w_10028.jpg': '

In [18]:
import csv

def generate_submission_csv(csv_file, image_whale_dict, updated_csv_file):
    """
    Update a CSV file with image-to-whale mappings and write the updates to a new CSV file.

    Args:
    - csv_file (str): Path to the CSV file.
    - image_whale_dict (dict): Dictionary containing image-to-whale mappings.
    - updated_csv_file (str): Path to the new CSV file to write the updates.

    Returns:
    - None
    """
    # Read the CSV file
    with open(csv_file, 'r') as file:
        reader = csv.reader(file)
        rows = list(reader)

    # Get the header row
    header = rows[0]
    
    # Find the index of the columns with whale IDs
    whale_id_indices = {header[i]: i for i in range(1, len(header))}

    # For each row in the CSV file
    for row in rows[1:]:
        image_name = row[0]
        if image_name in image_whale_dict:
            # Get the whale ID from the dictionary
            whale_id = image_whale_dict[image_name]
            
            # Update the corresponding whale ID column to 1 and others to 0
            for col_name, col_index in whale_id_indices.items():
                if col_name == f'whale_{whale_id}':
                    row[col_index] = '1'
                else:
                    row[col_index] = '0'

    # Write the modified data to the new CSV file
    with open(updated_csv_file, 'w', newline='') as file:
        writer = csv.writer(file)
        writer.writerows(rows)


# CSV file path
csv_file = 'dataset/sample_submission.csv'
output_csv_file = 'submissions/test_vision_transformer.csv'

# Dictionary mapping image file names to whale IDs
image_whale_dict = result_dict

# Call the function to update the CSV file
generate_submission_csv(csv_file, image_whale_dict, output_csv_file)