In [1]:
from s_clip_scripts.main import main
from s_clip_scripts.params import parse_args
from s_clip_scripts.data_loader import get_data
from s_clip_scripts.model import create_custom_model

from open_clip import create_model_and_transforms, get_tokenizer, create_loss

from itertools import chain
from textblob import TextBlob
import nltk
import re
import copy

# Fetch the NLTK resources for sentence/word tokenisation and part-of-speech tagging.
# Presumably these back the TextBlob `.words` and `.tags` calls made in format_labels
# below -- TODO confirm against the installed TextBlob version.
nltk.download('punkt')
nltk.download('averaged_perceptron_tagger')

def parse_str_args(str_args):
    """Tokenise a command-line style string and hand the tokens to ``parse_args``.

    The string is cut at newlines, each line is stripped of surrounding
    whitespace and split on single spaces, and empty fragments (produced by
    blank lines or repeated spaces) are discarded.

    Args:
        str_args: Arguments written as one string; may span several lines.

    Returns:
        Whatever ``parse_args`` returns for the resulting token list.
    """
    tokens = [
        token
        for line in str_args.split('\n')
        for token in line.strip().split(' ')
        if token  # drop the empty strings left behind by consecutive spaces
    ]
    return parse_args(tokens)

def get_model_and_data(args):
    """Create the model and the datasets selected by ``args``, forcing the CPU device.

    Args:
        args: Parsed arguments object. ``model``, ``pretrained``, ``precision``
            and ``aug_cfg`` are read from it here; ``device`` is overwritten
            with ``'cpu'`` on the object that was passed in. The object is also
            forwarded to ``create_custom_model`` and ``get_data``.

    Returns:
        A ``(model, data)`` pair: the output of ``create_custom_model`` and the
        output of ``get_data``.
    """
    # Everything in this notebook runs on the CPU, whatever the arguments say
    args.device = 'cpu'

    base_model, train_transform, val_transform = create_model_and_transforms(
        args.model,
        args.pretrained,
        precision=args.precision,
        device=args.device,
        output_dict=True,
        aug_cfg=args.aug_cfg,
    )
    custom_model = create_custom_model(args, base_model)

    tokenizer = get_tokenizer(args.model)
    datasets = get_data(args, (train_transform, val_transform), tokenizer=tokenizer)
    return custom_model, datasets

def format_labels(classnames, dataset):
    """Turn raw class labels into singular, lower-case phrases with an article.

    Each label is split on capital letters, '&' / ' and ' are rewritten to
    ' or ', every word is singularised with TextBlob, a handful of known bad
    singularisations are patched, and finally 'a' / 'an' is prepended unless
    the label ends in a plural noun or contains a mass noun.
    For example, 'blazers and suit jackets' becomes 'a blazer or suit jacket'.

    Args:
        classnames: Iterable of label strings.
        dataset: Name of the dataset the labels belong to. When it contains
            'Fashion', the remote-sensing ' area' suffix rule is skipped.

    Returns:
        dict mapping every original label to its formatted version.
    """
    # UCM labels are all lower case; rewrite the compound ones to the Pascal case
    # that RSICD uses so the capital-letter split below handles both datasets.
    ucm_to_pascal_case = (
        ('residential', 'Residential'),
        ('mobilehomepark', 'MobileHomePark'),
        ('tenniscourt', 'TennisCourt'),
        ('parkinglot', 'ParkingLot'),
        ('baseballdiamond', 'BaseballDiamond'),
        ('golfcourse', 'GolfCourse'),
        ('storagetank', 'StorageTank'),
    )

    # Patches applied, in this order, to the lower-cased singularised label.
    # The first entry keeps spelling consistent; the rest undo faulty
    # singularisation (e.g. 'ties' -> 'ty').
    # NOTE(review): these are plain substring replacements, so they also rewrite
    # longer words that merely contain the fragment, and they double the suffix
    # if the singulariser left the word intact -- confirm this is acceptable for
    # new label sets.
    singularization_fixes = (
        ('jewellery', 'jewelry'),
        ('ty', 'tie'),
        ('jean', 'jeans'),
        ('glass', 'glasses'),
        ('legging', 'leggings'),
        ('pant', 'pants'),
        ('bottom', 'bottoms'),
        ('overpas', 'overpass'),
        ('tenni', 'tennis'),
    )

    # Mass nouns are nouns which don't use an article
    mass_nouns = {'lingerie', 'jewelry', 'swimwear', 'underwear', 'outerwear', 'eyewear'}

    # The ' area' suffix is meant for remote-sensing labels only.
    # NOTE(review): only dataset names containing 'Fashion' are excluded here;
    # any other non-remote-sensing dataset name would still get the suffix -- confirm.
    allow_area_suffix = 'Fashion' not in dataset

    formatted_classnames = {}
    for original_c in classnames:
        # Strings are immutable, so the original label can be kept as the key as-is
        c = original_c
        for lowercase_name, pascal_name in ucm_to_pascal_case:
            c = c.replace(lowercase_name, pascal_name)

        # RS classes use pascal case (e.g. BareLand), here we split on the capitals (e.g. Bare Land)
        # From: https://stackoverflow.com/questions/2277352/split-a-string-at-uppercase-letters
        c = re.sub(r"([A-Z])", r" \1", c)
        # Check if this is necessary: replace '&' with 'and', then replace ' and ' with ' or '
        c = c.replace('&', 'and')
        c = c.replace(' and ', ' or ')

        # Make every word singular (e.g. fields -> field), then rebuild one
        # space-separated, lower-case string from the word list
        c = ' '.join(TextBlob(c).words.singularize()).lower()

        for wrong, right in singularization_fixes:
            c = c.replace(wrong, right)

        contains_mass_noun = not mass_nouns.isdisjoint(c.split())

        # For *remote-sensing* descriptors like 'agricultural', 'residential' (that end in 'al')
        # and for 'parking', add 'area' after it. `and` binds tighter than `or`, so the dataset
        # check is applied explicitly to both conditions instead of only to the 'parking' case.
        if allow_area_suffix and (c.endswith('al') or c == 'parking'):
            c += ' area'

        # If the last word is not plural (check the tag) and not a mass noun, we add an article.
        # .tags values have format (word, tag); a label that yields no tags at all
        # (e.g. an empty string) is left without an article instead of raising IndexError.
        tags = TextBlob(c).tags
        if tags and tags[-1][-1] != 'NNS' and not contains_mass_noun:
            # 'an' when the label starts with a vowel, 'a' otherwise
            article = 'an ' if c[0] in 'aeiou' else 'a '
            c = article + c

        formatted_classnames[original_c] = c
    return formatted_classnames

[nltk_data] Downloading package punkt to /home/nhollain/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package averaged_perceptron_tagger to
[nltk_data]     /home/nhollain/nltk_data...
[nltk_data]   Package averaged_perceptron_tagger is already up-to-
[nltk_data]       date!


In [2]:
# Zero-shot classification datasets whose class labels get reformatted below
zeroshot_datasets = [
    "Fashion200k-SUBCLS",
    "Fashion200k-CLS",
    "FashionGen-CLS",
    "FashionGen-SUBCLS",
    "Polyvore-CLS",
    "RSICD-CLS",
    "UCM-CLS",
]

for dataset in zeroshot_datasets:
    # Select the dataset through the same flag the command-line entry point uses
    str_args = f'--imagenet-val {dataset}'
    args = parse_str_args(str_args)
    # The model is built alongside the data; only data['classnames'] is used here
    model, data = get_model_and_data(args)
    formatted_classnames = format_labels(data['classnames'], dataset)
    print(formatted_classnames)

Fashion200k-SUBCLS (split: val)
CLS size: 29789
{'blazers and suit jackets': 'a blazer or suit jacket', 'blouses': 'a blouse', 'cargo pants': 'cargo pants', 'casual and day dresses': 'a casual or day dress', 'casual jackets': 'a casual jacket', 'cocktail dresses': 'a cocktail dress', 'cropped pants': 'cropped pants', 'denim jackets': 'a denim jacket', 'full length pants': 'full length pants', 'fur jackets': 'a fur jacket', 'gowns': 'a gown', 'harem pants': 'harem pants', 'knee length skirts': 'a knee length skirt', 'leather jackets': 'a leather jacket', 'leggings': 'leggings', 'long sleeved tops': 'a long sleeved top', 'maxi and long dresses': 'a maxi or long dress', 'maxi skirts': 'a maxi skirt', 'mid length skirts': 'a mid length skirt', 'mini and short dresses': 'a mini or short dress', 'mini skirts': 'a mini skirt', 'padded and down jackets': 'a padded or down jacket', 'prom and formal dresses': 'a prom or formal dress', 'shirts': 'a shirt', 'short sleeve tops': 'a short sleeve top