In [10]:
import numpy as np
import pandas as pd
import torch
import argparse
import evaluate
from datasets import load_dataset
from transformers import ViTImageProcessor, ViTForImageClassification, TrainingArguments, Trainer, AutoImageProcessor, ResNetForImageClassification

In [6]:
# Hub checkpoint id shared by the image processor (here) and the model loader.
model_name = "microsoft/resnet-50"
# Image processor matching the checkpoint; `transform` uses it to turn PIL
# images into pixel tensors.
feature_extractor = AutoImageProcessor.from_pretrained(
    pretrained_model_name_or_path=model_name
)

Downloading (…)rocessor_config.json:   0%|          | 0.00/266 [00:00<?, ?B/s]

Could not find image processor class in the image processor config or the model config. Loading based on pattern matching with the model's feature extractor configuration.


In [5]:
# ----------------
# functions
# ----------------
def transform(example_batch):
    """Turn a batch of dataset rows into model-ready inputs.

    Registered via ``dataset.with_transform`` so it is applied to batches as
    they are read.

    Args:
        example_batch: mapping with an ``'image'`` column (a list of PIL
            images) and a ``'labels'`` column (class ids).

    Returns:
        The ``feature_extractor`` output (holding ``pixel_values`` as PyTorch
        tensors) with the batch's ``labels`` attached unchanged.
    """
    # Normalise every image to 3-channel RGB first: grayscale ('L'), palette
    # ('P') or RGBA images would otherwise reach the processor with the wrong
    # number of channels. This is a no-op for images that are already RGB.
    images = [image.convert("RGB") for image in example_batch['image']]
    inputs = feature_extractor(images, return_tensors='pt')
    inputs['labels'] = example_batch['labels']
    return inputs

def collate_fn(batch):
    """Merge a list of per-example dicts into a single training batch.

    Args:
        batch: list of examples, each with a ``'pixel_values'`` tensor and an
            integer-like ``'labels'`` entry.

    Returns:
        dict with ``'pixel_values'`` (the example tensors stacked along a new
        leading batch dimension) and ``'labels'`` (a tensor of the labels).
    """
    stacked_pixels = torch.stack([example['pixel_values'] for example in batch])
    targets = torch.tensor([example['labels'] for example in batch])
    return {'pixel_values': stacked_pixels, 'labels': targets}

# Accuracy metric from the `evaluate` library; consumed by compute_metrics.
metric = evaluate.load(path="accuracy")
def compute_metrics(eval_pred):
    """Score Trainer evaluation output with the module-level ``metric``.

    Args:
        eval_pred: object exposing ``predictions`` (per-class scores with the
            class axis last, or a tuple whose first element is such scores)
            and ``label_ids`` (true class ids).

    Returns:
        Whatever ``metric.compute`` returns (e.g. ``{'accuracy': ...}`` for
        the accuracy metric).
    """
    logits = eval_pred.predictions
    # When a model returns extra outputs (e.g. hidden states), Trainer passes
    # a tuple; the logits come first. np.argmax on the raw tuple would fail.
    if isinstance(logits, tuple):
        logits = logits[0]
    return metric.compute(
        # axis=-1: class scores are on the last axis (axis 1 for 2-D logits).
        predictions=np.argmax(logits, axis=-1),
        references=eval_pred.label_ids,
    )

Downloading builder script:   0%|          | 0.00/4.20k [00:00<?, ?B/s]

In [7]:
## prepare data
print('-- Load Data...')
# Load the image-folder dataset for this split from disk.
dataset = load_dataset(path="imagefolder", data_dir="planet-imgs-original/split1/")
# Attach `transform` so rows are preprocessed on the fly when accessed,
# rather than converting the whole dataset up front.
prepared_ds = dataset.with_transform(transform=transform)

-- Load Data...


Resolving data files:   0%|          | 0/125 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/32 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/40 [00:00<?, ?it/s]

Downloading data files:   0%|          | 0/128 [00:00<?, ?it/s]

Downloading data files: 0it [00:00, ?it/s]

Extracting data files: 0it [00:00, ?it/s]

Downloading data files:   0%|          | 0/35 [00:00<?, ?it/s]

Downloading data files: 0it [00:00, ?it/s]

Extracting data files: 0it [00:00, ?it/s]

Downloading data files:   0%|          | 0/43 [00:00<?, ?it/s]

Downloading data files: 0it [00:00, ?it/s]

Extracting data files: 0it [00:00, ?it/s]

Generating train split: 0 examples [00:00, ? examples/s]

Generating validation split: 0 examples [00:00, ? examples/s]

Generating test split: 0 examples [00:00, ? examples/s]

In [8]:
## prepare labels
# Class ids are used directly as label "names", so both maps are identities
# over the position in `labels`.
labels = [0, 1]
label2id = {label: idx for idx, label in enumerate(labels)}
id2label = {idx: label for idx, label in enumerate(labels)}

In [11]:
## load model
print('-- Load Model...')
# The checkpoint ships a 1000-class classifier head. ignore_mismatched_sizes
# lets that head be re-initialised with one output per entry in id2label.
model = ResNetForImageClassification.from_pretrained(
    model_name,
    # Derive the head size from the label map instead of hard-coding 2, so
    # num_labels can never disagree with label2id / id2label.
    num_labels=len(id2label),
    label2id=label2id,
    id2label=id2label,
    ignore_mismatched_sizes=True,
)

-- Load Model...


Downloading (…)lve/main/config.json:   0%|          | 0.00/69.6k [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/103M [00:00<?, ?B/s]

Some weights of ResNetForImageClassification were not initialized from the model checkpoint at microsoft/resnet-50 and are newly initialized because the shapes did not match:
- classifier.1.weight: found shape torch.Size([1000, 2048]) in the checkpoint and torch.Size([2, 2048]) in the model instantiated
- classifier.1.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([2]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [15]:
## training args
s = 0  # zero-based split index; the output folder uses the 1-based name
output_path = f"{model_name.split('/')[-1]}/split{s + 1}"
training_args = TrainingArguments(
    output_dir=output_path,
    # evaluate, checkpoint and log on the same 20-step cadence
    evaluation_strategy="steps",
    eval_steps=20,
    save_strategy="steps",
    save_steps=20,
    save_total_limit=1,
    logging_strategy="steps",
    logging_steps=20,
    # optimisation
    # NOTE(review): if the backbone uses BatchNorm (ResNets typically do), a
    # batch of 2 gives very noisy batch statistics -- worth checking against
    # the validation-loss jump seen around step 60. TODO confirm / try larger.
    per_device_train_batch_size=2,
    num_train_epochs=10,
    learning_rate=5e-6,
    # fp16=True,  # mixed precision, currently disabled
    # misc: keep the 'image' column for `transform`; no metric_for_best_model
    # is set, so the Trainer default (eval loss) picks the best checkpoint.
    remove_unused_columns=False,
    push_to_hub=False,
    load_best_model_at_end=True,
)

In [16]:
## trainer
print('-- Training...')
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=prepared_ds["train"],
    eval_dataset=prepared_ds["validation"],
    data_collator=collate_fn,
    compute_metrics=compute_metrics,
    # Passing the image processor here makes Trainer save its
    # preprocessor_config.json alongside each checkpoint.
    tokenizer=feature_extractor,
)

-- Training...


In [None]:
## training
# resume_from_checkpoint=None is the default, spelled out: train the in-memory
# `model` instead of resuming from a checkpoint under output_dir.
# NOTE(review): re-running this cell continues from the already-updated
# weights held by `model`; reload the model first for a fresh run.
train_results = trainer.train(resume_from_checkpoint=None)

***** Running training *****
  Num examples = 124
  Num Epochs = 10
  Instantaneous batch size per device = 2
  Total train batch size (w. parallel, distributed & accumulation) = 2
  Gradient Accumulation steps = 1
  Total optimization steps = 620
  Number of trainable parameters = 23512130


Step,Training Loss,Validation Loss,Accuracy
20,0.6898,0.695154,0.516129
40,0.6928,0.692773,0.580645
60,0.7015,3.897073,0.483871


***** Running Evaluation *****
  Num examples = 31
  Batch size = 8
Saving model checkpoint to resnet-50/split1/checkpoint-20
Configuration saved in resnet-50/split1/checkpoint-20/config.json
Model weights saved in resnet-50/split1/checkpoint-20/pytorch_model.bin
Image processor saved in resnet-50/split1/checkpoint-20/preprocessor_config.json
***** Running Evaluation *****
  Num examples = 31
  Batch size = 8
Saving model checkpoint to resnet-50/split1/checkpoint-40
Configuration saved in resnet-50/split1/checkpoint-40/config.json
Model weights saved in resnet-50/split1/checkpoint-40/pytorch_model.bin
Image processor saved in resnet-50/split1/checkpoint-40/preprocessor_config.json
Deleting older checkpoint [resnet-50/split1/checkpoint-20] due to args.save_total_limit
***** Running Evaluation *****
  Num examples = 31
  Batch size = 8
Saving model checkpoint to resnet-50/split1/checkpoint-60
Configuration saved in resnet-50/split1/checkpoint-60/config.json
Model weights saved in resnet-