# In [25]:

import torch
from jutils import open_img_id, img_exists, SquarePad, onehot, labelstring

import numpy as np
import pandas as pd
from sklearn.linear_model import LogisticRegression
from xgboost.sklearn import XGBClassifier
from sklearn.multioutput import ClassifierChain
import torchvision.transforms.functional as fn


from transformers import BeitFeatureExtractor, BeitForImageClassification
from PIL import Image

# Single pretrained BEiT checkpoint used for both the image preprocessing
# (feature_extractor) and the network whose logits serve as features (model).
BEIT_CHECKPOINT = 'microsoft/beit-base-patch16-224-pt22k-ft22k'
feature_extractor = BeitFeatureExtractor.from_pretrained(BEIT_CHECKPOINT)
model = BeitForImageClassification.from_pretrained(BEIT_CHECKPOINT)


# Read the training metadata and attach an image object to every row.
# Each image is opened via open_img_id and then passed through the SquarePad
# helper from jutils.
padder = SquarePad()
df = pd.read_csv('train_Kea.csv')
# Throw away rows whose image is missing. The explicit .copy() detaches the
# filtered frame from the one read from disk, so adding the 'Images' column
# below does not raise pandas' SettingWithCopyWarning.
df = df.loc[df.image_id.apply(img_exists)].copy()
df['Images'] = df['image_id'].apply(open_img_id).apply(padder)

import torchvision.transforms as T
import torchvision.transforms.functional as fn
# Augmentation recipes. Every recipe is applied to every training image, so
# the training set grows to (1 + len(transforms)) times its original size.
# The list order fixes the order of the augmented copies appended to `df`.
transforms = [
    fn.hflip,                                                   # horizontal flip
    T.Compose([T.GaussianBlur(9)]),                             # blur, kernel size 9
    T.Compose([fn.hflip, T.GaussianBlur(9)]),                   # horizontal flip + blur
    T.Compose([fn.hflip, T.RandomAdjustSharpness(0.1, p=1)]),   # horizontal flip + sharpness 0.1 (p=1: always applied)
    fn.vflip,                                                   # vertical flip
    T.Compose([fn.vflip, T.GaussianBlur(9)]),                   # vertical flip + blur
    T.Compose([fn.vflip, T.RandomAdjustSharpness(0.1, p=1)]),   # vertical flip + sharpness 0.1 (p=1: always applied)
]

# One full copy of the frame per recipe, identical to the original except for
# the transformed 'Images' column (labels and ids are carried over unchanged).
augmented = [
    df.assign(Images=df['Images'].apply(augment))
    for augment in transforms
]

# Original rows first, followed by the augmented copies in recipe order.
df = pd.concat([df, *augmented])

# Build the training targets: one multi-hot row per (augmented) training image.
labelsdf = pd.read_csv('labels.csv')
labels = labelsdf['object'].values.tolist()
y_train = np.array([onehot(lbl) for lbl in df['labels']]).astype(int)

# Use the BEiT logits as features. The images are pushed through the network
# in mini-batches rather than in one forward pass over the whole augmented
# set, so peak memory is bounded by BATCH_SIZE instead of the dataset size.
BATCH_SIZE = 32
train_images = df['Images'].tolist()
logit_batches = []
with torch.no_grad():
    for start in range(0, len(train_images), BATCH_SIZE):
        inputs = feature_extractor(
            images=train_images[start:start + BATCH_SIZE],
            return_tensors="pt",
        )
        outputs = model(**inputs)
        logit_batches.append(outputs['logits'])
# Hand scikit-learn a plain NumPy array instead of relying on implicit
# tensor-to-array conversion.
X_train = torch.cat(logit_batches).numpy()

# Multi-label classifier: a chain of class-balanced logistic regressions fitted
# on the BEiT logits.
logres = LogisticRegression(dual=True, solver='liblinear', random_state=342985, max_iter=400, class_weight='balanced')
final = ClassifierChain(logres)
final.fit(X_train, y_train)  # logistic regression


# Predict a label string for every test image and write the submission file.
testdf = pd.read_csv('test.csv')
testlabels = []
labelsdf = pd.read_csv('labels.csv')
for img_id in testdf.image_id:
    # Only the image load is expected to fail; keep the try body minimal so
    # unrelated errors are not mistaken for a missing file.
    try:
        # Pass the image through the same padder as the training images so the
        # classifier sees identically preprocessed inputs at prediction time.
        image = padder(open_img_id(img_id))
    except FileNotFoundError:
        print(img_id, 'missing, defaulting to l0')
        testlabels.append('l0')
        continue

    inputs = feature_extractor([image], return_tensors="pt")
    with torch.no_grad():
        x = model(**inputs)
    prediction = final.predict(x['logits'])
    predicted_labels = labelstring(prediction.astype(bool))

    # Fall back to 'l1' when the classifier predicts no label at all.
    testlabels.append(predicted_labels if predicted_labels else 'l1')

    # Log the image id together with the object names of the chosen labels.
    print(img_id,
          ' '.join(labelsdf.loc[labelsdf.label_id.isin(testlabels[-1].split(' ')), 'object'].values.ravel()),
          sep='\t')

testdf['labels'] = testlabels
# Test mfscore 0.56 (recorded before test images were padded like the training set)
testdf.to_csv('joosep_submissions/kea_beit_wscraped_modified.csv', index=False)


