In [1]:
import pandas as pd
from changeit3d.utils.basics import make_train_test_val_splits

%load_ext autoreload
%autoreload 2


#### If you have downloaded our pretrained weights, you should be able to find the output of running this notebook here:
    <top-downloaded-dir>/shapetalk/misc/shapetalk_preprocessed_public_utters_for_listening_oracle_version_0.csv


In [2]:
loads = [0.925, 0.025, 0.05]  # since this will be used as an oracle model; the more training data it has seen the better
random_seed = 2023
shapetalk_file = '../../data/shapetalk/language/shapetalk/shapetalk_preprocessed_public_version_0.csv'
out_save_csv_file = '../../data/shapetalk/misc/shapetalk_preprocessed_public_utters_for_listening_oracle_version_0.csv'
save_output = False

In [3]:
df = pd.read_csv(shapetalk_file)
print('Total collected entries/utterances', len(df))
df = df[df.listening_split != 'ignore']
df.reset_index(inplace=True, drop=True)

Total collected entries/utterances 536596


In [4]:
all_model_uid = set(df.target_uid.unique())
all_model_uid = all_model_uid.union(df.source_uid.unique())
print(len(all_model_uid))
models_df = pd.DataFrame(all_model_uid, columns=['model_uid'])
models_df['shape_class'] = models_df.model_uid.apply(lambda x: x.split('/')[0] )
models_df.head()

36391


Unnamed: 0,model_uid,shape_class
0,flowerpot/ShapeNet/fa9e6aa9d67d5329e5c3c728484...,flowerpot
1,chair/ShapeNet/fb912528e642f6ea7c7cfdf5546967dd,chair
2,table/ShapeNet/ff1c8d1e157f3b74b0ceed2c36e897b8,table
3,airplane/ShapeNet/165c4491d10067b3bd46d022fd7d...,airplane
4,sofa/ShapeNet/bd5bc3e6189f6972eff42b9e13c388bc,sofa


In [5]:
# split in EVERY class independently the requested percentages
all_splitted_dfs = []
for shape_class in models_df.shape_class.unique():
    sub_df = models_df[models_df.shape_class == shape_class].copy()
    sub_df.reset_index(inplace=True, drop=True)
    sub_df = make_train_test_val_splits(sub_df, loads, random_seed, split_column='model_uid', verbose=False)
    all_splitted_dfs.append(sub_df)
    print(f"{shape_class: <10}", 
          "train: {:5d}  val: {:5d}  test: {:5d}".format((sub_df.split == "train").sum(),
                                                           (sub_df.split == "val").sum(),
                                                           (sub_df.split == "test").sum(),
                                                          ))
result = pd.concat(all_splitted_dfs)
result.reset_index(inplace=True, drop=True)

flowerpot  train:   576  val:    16  test:    31
chair      train:  6112  val:   166  test:   330
table      train:  7562  val:   205  test:   409
airplane   train:  2517  val:    69  test:   136
sofa       train:  2835  val:    77  test:   153
display    train:  1085  val:    30  test:    59
bench      train:  1532  val:    42  test:    83
guitar     train:   696  val:    19  test:    38
bottle     train:   455  val:    12  test:    25
lamp       train:  2138  val:    58  test:   116
cap        train:   192  val:     6  test:    10
dresser    train:  1563  val:    43  test:    84
vase       train:   761  val:    21  test:    41
bed        train:   690  val:    20  test:    37
plant      train:   259  val:     8  test:    14
bag        train:   128  val:     4  test:     7
mug        train:   192  val:     6  test:    10
bookshelf  train:   754  val:    21  test:    41
faucet     train:   591  val:    17  test:    32
pistol     train:   279  val:     8  test:    15
person     train:   

In [6]:
merged = df.merge(result, left_on='target_uid', right_on='model_uid')
assert all(merged.shape_class == merged.target_object_class)
assert all(merged.model_uid == merged.target_uid)

In [7]:
merged = merged.drop(columns=['target_unary_split', 'source_unary_split', 
                              'listening_split', 'changeit_split', 
                              'shape_class', 'model_uid'])
merged = merged.rename(columns={'split':'listening_split'})

In [8]:
for split in ['train', 'test', 'val']:
    print((merged.listening_split == split).mean())

0.9244131261520673
0.05005837766406752
0.02552849618386524


In [9]:
train_targets = set(merged[merged.listening_split == 'train']['target_uid'])
test_targets = set(merged[merged.listening_split == 'test']['target_uid'])
val_targets = set(merged[merged.listening_split == 'val']['target_uid'])

assert len(train_targets.intersection(test_targets)) == 0 
assert len(train_targets.intersection(val_targets)) == 0 
assert len(test_targets.intersection(val_targets)) == 0 

In [10]:
if save_output:
    merged.to_csv(out_save_csv_file , index=False)