DATASETS Library

In [1]:
from datasets import load_dataset
data_files = {'train': 'data/drugsComTrain_raw.tsv', 'test': 'data/drugsComTest_raw.tsv'}
drug_dataset = load_dataset('csv', data_files=data_files, delimiter='\t')

In [2]:
drug_dataset.shape

{'train': (161297, 7), 'test': (53766, 7)}

In [3]:
drug_dataset['train'].features

{'Unnamed: 0': Value(dtype='int64', id=None),
 'drugName': Value(dtype='string', id=None),
 'condition': Value(dtype='string', id=None),
 'review': Value(dtype='string', id=None),
 'rating': Value(dtype='float64', id=None),
 'date': Value(dtype='string', id=None),
 'usefulCount': Value(dtype='int64', id=None)}

In [4]:
drug_dataset['train'][:3]

{'Unnamed: 0': [206461, 95260, 92703],
 'drugName': ['Valsartan', 'Guanfacine', 'Lybrel'],
 'condition': ['Left Ventricular Dysfunction', 'ADHD', 'Birth Control'],
 'review': ['"It has no side effect, I take it in combination of Bystolic 5 Mg and Fish Oil"',
  '"My son is halfway through his fourth week of Intuniv. We became concerned when he began this last week, when he started taking the highest dose he will be on. For two days, he could hardly get out of bed, was very cranky, and slept for nearly 8 hours on a drive home from school vacation (very unusual for him.) I called his doctor on Monday morning and she said to stick it out a few days. See how he did at school, and with getting up in the morning. The last two days have been problem free. He is MUCH more agreeable than ever. He is less emotional (a good thing), less cranky. He is remembering all the things he should. Overall his behavior is better. \r\nWe have tried many different medications and so far this is the most effect

In [5]:
drug_sample = drug_dataset['train'].shuffle(seed=42).select(range(1000))
drug_sample[:3]

{'Unnamed: 0': [87571, 178045, 80482],
 'drugName': ['Naproxen', 'Duloxetine', 'Mobic'],
 'condition': ['Gout, Acute', 'ibromyalgia', 'Inflammatory Conditions'],
 'review': ['"like the previous person mention, I&#039;m a strong believer of aleve, it works faster for my gout than the prescription meds I take. No more going to the doctor for refills.....Aleve works!"',
  '"I have taken Cymbalta for about a year and a half for fibromyalgia pain. It is great\r\nas a pain reducer and an anti-depressant, however, the side effects outweighed \r\nany benefit I got from it. I had trouble with restlessness, being tired constantly,\r\ndizziness, dry mouth, numbness and tingling in my feet, and horrible sweating. I am\r\nbeing weaned off of it now. Went from 60 mg to 30mg and now to 15 mg. I will be\r\noff completely in about a week. The fibro pain is coming back, but I would rather deal with it than the side effects."',
  '"I have been taking Mobic for over a year with no side effects other than 

In [6]:
drug_dataset

DatasetDict({
    train: Dataset({
        features: ['Unnamed: 0', 'drugName', 'condition', 'review', 'rating', 'date', 'usefulCount'],
        num_rows: 161297
    })
    test: Dataset({
        features: ['Unnamed: 0', 'drugName', 'condition', 'review', 'rating', 'date', 'usefulCount'],
        num_rows: 53766
    })
})

In [7]:
for split in drug_dataset.keys():
    assert len(drug_dataset[split]) == len(drug_dataset[split].unique('Unnamed: 0'))

In [8]:
drug_dataset = drug_dataset.rename_column(original_column_name='Unnamed: 0', new_column_name='patient_id')
drug_dataset

DatasetDict({
    train: Dataset({
        features: ['patient_id', 'drugName', 'condition', 'review', 'rating', 'date', 'usefulCount'],
        num_rows: 161297
    })
    test: Dataset({
        features: ['patient_id', 'drugName', 'condition', 'review', 'rating', 'date', 'usefulCount'],
        num_rows: 53766
    })
})

In [9]:
def lowercase_condition(example):
    return {'condition': example['condition'].lower()}

In [10]:
drug_dataset['train'][:3]

{'patient_id': [206461, 95260, 92703],
 'drugName': ['Valsartan', 'Guanfacine', 'Lybrel'],
 'condition': ['Left Ventricular Dysfunction', 'ADHD', 'Birth Control'],
 'review': ['"It has no side effect, I take it in combination of Bystolic 5 Mg and Fish Oil"',
  '"My son is halfway through his fourth week of Intuniv. We became concerned when he began this last week, when he started taking the highest dose he will be on. For two days, he could hardly get out of bed, was very cranky, and slept for nearly 8 hours on a drive home from school vacation (very unusual for him.) I called his doctor on Monday morning and she said to stick it out a few days. See how he did at school, and with getting up in the morning. The last two days have been problem free. He is MUCH more agreeable than ever. He is less emotional (a good thing), less cranky. He is remembering all the things he should. Overall his behavior is better. \r\nWe have tried many different medications and so far this is the most effect

In [11]:
def filter_nones(x):
    return x['condition'] is not None

In [12]:
drug_dataset = drug_dataset.filter(lambda x: x['condition'] is not None)
drug_dataset

DatasetDict({
    train: Dataset({
        features: ['patient_id', 'drugName', 'condition', 'review', 'rating', 'date', 'usefulCount'],
        num_rows: 160398
    })
    test: Dataset({
        features: ['patient_id', 'drugName', 'condition', 'review', 'rating', 'date', 'usefulCount'],
        num_rows: 53471
    })
})

In [13]:
drug_dataset = drug_dataset.map(lowercase_condition)
drug_dataset['train'][:3]

{'patient_id': [206461, 95260, 92703],
 'drugName': ['Valsartan', 'Guanfacine', 'Lybrel'],
 'condition': ['left ventricular dysfunction', 'adhd', 'birth control'],
 'review': ['"It has no side effect, I take it in combination of Bystolic 5 Mg and Fish Oil"',
  '"My son is halfway through his fourth week of Intuniv. We became concerned when he began this last week, when he started taking the highest dose he will be on. For two days, he could hardly get out of bed, was very cranky, and slept for nearly 8 hours on a drive home from school vacation (very unusual for him.) I called his doctor on Monday morning and she said to stick it out a few days. See how he did at school, and with getting up in the morning. The last two days have been problem free. He is MUCH more agreeable than ever. He is less emotional (a good thing), less cranky. He is remembering all the things he should. Overall his behavior is better. \r\nWe have tried many different medications and so far this is the most effect

In [14]:
def compute_review_length(example):
    return {'review_length': len(example['review'].split())}

In [15]:
drug_dataset = drug_dataset.map(compute_review_length)
drug_dataset['train'][0]

{'patient_id': 206461,
 'drugName': 'Valsartan',
 'condition': 'left ventricular dysfunction',
 'review': '"It has no side effect, I take it in combination of Bystolic 5 Mg and Fish Oil"',
 'rating': 9.0,
 'date': 'May 20, 2012',
 'usefulCount': 27,
 'review_length': 17}

In [16]:
drug_dataset['train'].sort('review_length')[:3]

{'patient_id': [111469, 13653, 53602],
 'drugName': ['Ledipasvir / sofosbuvir',
  'Amphetamine / dextroamphetamine',
  'Alesse'],
 'condition': ['hepatitis c', 'adhd', 'birth control'],
 'review': ['"Headache"', '"Great"', '"Awesome"'],
 'rating': [10.0, 10.0, 10.0],
 'date': ['February 3, 2015', 'October 20, 2009', 'November 23, 2015'],
 'usefulCount': [41, 3, 0],
 'review_length': [1, 1, 1]}

In [17]:
drug_dataset

DatasetDict({
    train: Dataset({
        features: ['patient_id', 'drugName', 'condition', 'review', 'rating', 'date', 'usefulCount', 'review_length'],
        num_rows: 160398
    })
    test: Dataset({
        features: ['patient_id', 'drugName', 'condition', 'review', 'rating', 'date', 'usefulCount', 'review_length'],
        num_rows: 53471
    })
})

In [18]:
drug_dataset = drug_dataset.filter(lambda x: x['review_length'] > 30)
print(drug_dataset.num_rows)

{'train': 138514, 'test': 46108}


In [19]:
import html
drug_dataset = drug_dataset.map(lambda x: {'review': [html.unescape(o) for o in x['review']]}, batched=True)

In [20]:
drug_dataset.num_rows

{'train': 138514, 'test': 46108}

In [21]:
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained('bert-base-cased')

def tokenize_function(examples):
    return tokenizer(examples['review'], truncation=True)



In [22]:
%time tokenized_dataset = drug_dataset.map(tokenize_function, batched=True)

Map:   0%|          | 0/46108 [00:00<?, ? examples/s]

Map: 100%|██████████| 46108/46108 [00:11<00:00, 4103.55 examples/s]

CPU times: user 16 s, sys: 515 ms, total: 16.5 s
Wall time: 11.3 s





In [23]:
%time tokenized_dataset = drug_dataset.map(tokenize_function, batched=False)

CPU times: user 19 ms, sys: 4.08 ms, total: 23.1 ms
Wall time: 24.7 ms


In [25]:
slow_tokenizer = AutoTokenizer.from_pretrained('bert-base-cased', use_fast=False)
def slow_tokenize_function(examples):
    return slow_tokenizer(examples['review'], truncation=True)

%time tokenized_dataset = drug_dataset.map(slow_tokenize_function, batched=True)

Map: 100%|██████████| 138514/138514 [02:00<00:00, 1149.17 examples/s]
Map: 100%|██████████| 46108/46108 [00:39<00:00, 1160.32 examples/s]

CPU times: user 2min 39s, sys: 919 ms, total: 2min 40s
Wall time: 2min 40s





In [26]:
%time tokenized_dataset = drug_dataset.map(slow_tokenize_function, batched=False)

Map: 100%|██████████| 138514/138514 [02:35<00:00, 890.88 examples/s] 
Map: 100%|██████████| 46108/46108 [00:51<00:00, 886.76 examples/s] 

CPU times: user 3min 24s, sys: 3.21 s, total: 3min 28s
Wall time: 3min 27s





In [27]:
def tokenize_and_split(examples):
    return tokenizer(
        examples["review"],
        truncation=True,
        max_length=128,
        return_overflowing_tokens=True,
    )

In [28]:
drug_dataset['train'][0]

{'patient_id': 95260,
 'drugName': 'Guanfacine',
 'condition': 'adhd',
 'review': '"My son is halfway through his fourth week of Intuniv. We became concerned when he began this last week, when he started taking the highest dose he will be on. For two days, he could hardly get out of bed, was very cranky, and slept for nearly 8 hours on a drive home from school vacation (very unusual for him.) I called his doctor on Monday morning and she said to stick it out a few days. See how he did at school, and with getting up in the morning. The last two days have been problem free. He is MUCH more agreeable than ever. He is less emotional (a good thing), less cranky. He is remembering all the things he should. Overall his behavior is better. \r\nWe have tried many different medications and so far this is the most effective."',
 'rating': 8.0,
 'date': 'April 27, 2010',
 'usefulCount': 192,
 'review_length': 141}

In [30]:
drug_dataset['train'][0]['review']

'"My son is halfway through his fourth week of Intuniv. We became concerned when he began this last week, when he started taking the highest dose he will be on. For two days, he could hardly get out of bed, was very cranky, and slept for nearly 8 hours on a drive home from school vacation (very unusual for him.) I called his doctor on Monday morning and she said to stick it out a few days. See how he did at school, and with getting up in the morning. The last two days have been problem free. He is MUCH more agreeable than ever. He is less emotional (a good thing), less cranky. He is remembering all the things he should. Overall his behavior is better. \r\nWe have tried many different medications and so far this is the most effective."'

In [29]:
result = tokenize_and_split(drug_dataset["train"][0])
[len(inp) for inp in result["input_ids"]]

[128, 49]

In [35]:
drug_dataset['train'][:3]

{'patient_id': [95260, 92703, 138000],
 'drugName': ['Guanfacine', 'Lybrel', 'Ortho Evra'],
 'condition': ['adhd', 'birth control', 'birth control'],
 'review': ['"My son is halfway through his fourth week of Intuniv. We became concerned when he began this last week, when he started taking the highest dose he will be on. For two days, he could hardly get out of bed, was very cranky, and slept for nearly 8 hours on a drive home from school vacation (very unusual for him.) I called his doctor on Monday morning and she said to stick it out a few days. See how he did at school, and with getting up in the morning. The last two days have been problem free. He is MUCH more agreeable than ever. He is less emotional (a good thing), less cranky. He is remembering all the things he should. Overall his behavior is better. \r\nWe have tried many different medications and so far this is the most effective."',
  '"I used to take another oral contraceptive, which had 21 pill cycle, and was very happy-

In [36]:
drug_dataset.set_format('pandas')
drug_dataset['train'][:3]

Unnamed: 0,patient_id,drugName,condition,review,rating,date,usefulCount,review_length
0,95260,Guanfacine,adhd,"""My son is halfway through his fourth week of ...",8.0,"April 27, 2010",192,141
1,92703,Lybrel,birth control,"""I used to take another oral contraceptive, wh...",5.0,"December 14, 2009",17,134
2,138000,Ortho Evra,birth control,"""This is my first time using any form of birth...",8.0,"November 3, 2015",10,89


In [37]:
drug_dataset

DatasetDict({
    train: Dataset({
        features: ['patient_id', 'drugName', 'condition', 'review', 'rating', 'date', 'usefulCount', 'review_length'],
        num_rows: 138514
    })
    test: Dataset({
        features: ['patient_id', 'drugName', 'condition', 'review', 'rating', 'date', 'usefulCount', 'review_length'],
        num_rows: 46108
    })
})

In [38]:
drug_dataset_clean = drug_dataset["train"].train_test_split(train_size=0.8, seed=42)


In [39]:
drug_dataset_clean

DatasetDict({
    train: Dataset({
        features: ['patient_id', 'drugName', 'condition', 'review', 'rating', 'date', 'usefulCount', 'review_length'],
        num_rows: 110811
    })
    test: Dataset({
        features: ['patient_id', 'drugName', 'condition', 'review', 'rating', 'date', 'usefulCount', 'review_length'],
        num_rows: 27703
    })
})

In [40]:
# Rename the default "test" split to "validation"
drug_dataset_clean["validation"] = drug_dataset_clean.pop("test")
# Add the "test" set to our `DatasetDict`
drug_dataset_clean["test"] = drug_dataset["test"]
drug_dataset_clean

DatasetDict({
    train: Dataset({
        features: ['patient_id', 'drugName', 'condition', 'review', 'rating', 'date', 'usefulCount', 'review_length'],
        num_rows: 110811
    })
    validation: Dataset({
        features: ['patient_id', 'drugName', 'condition', 'review', 'rating', 'date', 'usefulCount', 'review_length'],
        num_rows: 27703
    })
    test: Dataset({
        features: ['patient_id', 'drugName', 'condition', 'review', 'rating', 'date', 'usefulCount', 'review_length'],
        num_rows: 46108
    })
})

In [41]:
drug_dataset_clean.items()

dict_items([('train', Dataset({
    features: ['patient_id', 'drugName', 'condition', 'review', 'rating', 'date', 'usefulCount', 'review_length'],
    num_rows: 110811
})), ('validation', Dataset({
    features: ['patient_id', 'drugName', 'condition', 'review', 'rating', 'date', 'usefulCount', 'review_length'],
    num_rows: 27703
})), ('test', Dataset({
    features: ['patient_id', 'drugName', 'condition', 'review', 'rating', 'date', 'usefulCount', 'review_length'],
    num_rows: 46108
}))])

Semantic search

In [1]:
from datasets import load_dataset
issue_dataset = load_dataset('lewtun/github-issues', split='train')
issue_dataset

Downloading readme: 100%|██████████| 10.5k/10.5k [00:00<00:00, 26.0MB/s]
Repo card metadata block was not found. Setting CardData to empty.
Downloading data: 100%|██████████| 12.2M/12.2M [00:00<00:00, 12.9MB/s]
Generating train split: 100%|██████████| 3019/3019 [00:00<00:00, 16763.47 examples/s]


Dataset({
    features: ['url', 'repository_url', 'labels_url', 'comments_url', 'events_url', 'html_url', 'id', 'node_id', 'number', 'title', 'user', 'labels', 'state', 'locked', 'assignee', 'assignees', 'milestone', 'comments', 'created_at', 'updated_at', 'closed_at', 'author_association', 'active_lock_reason', 'pull_request', 'body', 'timeline_url', 'performed_via_github_app', 'is_pull_request'],
    num_rows: 3019
})

In [9]:
import pandas as pd
pd.DataFrame(issue_dataset)[:3].T

Unnamed: 0,0,1,2
url,https://api.github.com/repos/huggingface/datas...,https://api.github.com/repos/huggingface/datas...,https://api.github.com/repos/huggingface/datas...
repository_url,https://api.github.com/repos/huggingface/datasets,https://api.github.com/repos/huggingface/datasets,https://api.github.com/repos/huggingface/datasets
labels_url,https://api.github.com/repos/huggingface/datas...,https://api.github.com/repos/huggingface/datas...,https://api.github.com/repos/huggingface/datas...
comments_url,https://api.github.com/repos/huggingface/datas...,https://api.github.com/repos/huggingface/datas...,https://api.github.com/repos/huggingface/datas...
events_url,https://api.github.com/repos/huggingface/datas...,https://api.github.com/repos/huggingface/datas...,https://api.github.com/repos/huggingface/datas...
html_url,https://github.com/huggingface/datasets/pull/2955,https://github.com/huggingface/datasets/pull/2954,https://github.com/huggingface/datasets/pull/2952
id,1003999469,1003904803,1002704096
node_id,PR_kwDODunzps4sHuRu,PR_kwDODunzps4sHa8O,PR_kwDODunzps4sDU8S
number,2955,2954,2952
title,Update legacy Python image for CI tests in Linux,Run tests in parallel,Fix missing conda deps


In [16]:
issue_dataset = issue_dataset.filter(lambda x: x['is_pull_request'] == False and len(x['comments']) > 0)
issue_dataset

Filter: 100%|██████████| 3019/3019 [00:00<00:00, 11618.39 examples/s]


Dataset({
    features: ['url', 'repository_url', 'labels_url', 'comments_url', 'events_url', 'html_url', 'id', 'node_id', 'number', 'title', 'user', 'labels', 'state', 'locked', 'assignee', 'assignees', 'milestone', 'comments', 'created_at', 'updated_at', 'closed_at', 'author_association', 'active_lock_reason', 'pull_request', 'body', 'timeline_url', 'performed_via_github_app', 'is_pull_request'],
    num_rows: 808
})

In [21]:
columns = issue_dataset.column_names
columns_to_keep = ['title', 'body', 'html_url', 'comments']
columns_to_remove = set(columns_to_keep).symmetric_difference(columns)
issue_dataset = issue_dataset.remove_columns(columns_to_remove)
issue_dataset

Dataset({
    features: ['html_url', 'title', 'comments', 'body'],
    num_rows: 808
})

In [22]:
pd.DataFrame(issue_dataset)[:3]

Unnamed: 0,html_url,title,comments,body
0,https://github.com/huggingface/datasets/issues...,Protect master branch,"[Cool, I think we can do both :), @lhoestq now...",After accidental merge commit (91c55355b634d0d...
1,https://github.com/huggingface/datasets/issues...,Backwards compatibility broken for cached data...,[Hi ! I guess the caching mechanism should hav...,## Describe the bug\r\nAfter upgrading to data...
2,https://github.com/huggingface/datasets/issues...,OSCAR unshuffled_original_ko: NonMatchingSplit...,[I tried `unshuffled_original_da` and it is al...,## Describe the bug\r\n\r\nCannot download OSC...


In [23]:
issue_dataset.set_format('pandas')
df = issue_dataset[:]
df.head()

Unnamed: 0,html_url,title,comments,body
0,https://github.com/huggingface/datasets/issues...,Protect master branch,"[Cool, I think we can do both :), @lhoestq now...",After accidental merge commit (91c55355b634d0d...
1,https://github.com/huggingface/datasets/issues...,Backwards compatibility broken for cached data...,[Hi ! I guess the caching mechanism should hav...,## Describe the bug\r\nAfter upgrading to data...
2,https://github.com/huggingface/datasets/issues...,OSCAR unshuffled_original_ko: NonMatchingSplit...,[I tried `unshuffled_original_da` and it is al...,## Describe the bug\r\n\r\nCannot download OSC...
3,https://github.com/huggingface/datasets/issues...,load_dataset using default cache on Windows ca...,"[Hi @daqieq, thanks for reporting.\r\n\r\nUnfo...",## Describe the bug\r\nStandard process to dow...
4,https://github.com/huggingface/datasets/issues...,to_tf_dataset keeps a reference to the open da...,"[I did some investigation and, as it seems, th...",To reproduce:\r\n```python\r\nimport datasets ...


In [29]:
df['comments'][0].tolist()

['Cool, I think we can do both :)',
 '@lhoestq now the 2 are implemented.\r\n\r\nPlease note that for the the second protection, finally I have chosen to protect the master branch only from **merge commits** (see update comment above), so no need to disable/re-enable the protection on each release (direct commits, different from merge commits, can be pushed to the remote master branch; and eventually reverted without messing up the repo history).']

In [31]:
comments_df = df.explode('comments', ignore_index=True)
comments_df.head(3).T

Unnamed: 0,0,1,2
html_url,https://github.com/huggingface/datasets/issues...,https://github.com/huggingface/datasets/issues...,https://github.com/huggingface/datasets/issues...
title,Protect master branch,Protect master branch,Backwards compatibility broken for cached data...
comments,"Cool, I think we can do both :)",@lhoestq now the 2 are implemented.\r\n\r\nPle...,Hi ! I guess the caching mechanism should have...
body,After accidental merge commit (91c55355b634d0d...,After accidental merge commit (91c55355b634d0d...,## Describe the bug\r\nAfter upgrading to data...


In [32]:
from datasets import Dataset
comments_dataset = Dataset.from_pandas(comments_df)
comments_dataset

Dataset({
    features: ['html_url', 'title', 'comments', 'body'],
    num_rows: 2964
})

In [35]:
comments_dataset = comments_dataset.map(lambda x: {'comment_length': len(x['comments'].split())})
comments_dataset

Map: 100%|██████████| 2964/2964 [00:00<00:00, 14019.91 examples/s]


Dataset({
    features: ['html_url', 'title', 'comments', 'body', 'comment_length'],
    num_rows: 2964
})

In [37]:
comments_dataset = comments_dataset.filter(lambda x: x['comment_length'] > 15)
comments_dataset

Filter:   0%|          | 0/2964 [00:00<?, ? examples/s]

Filter: 100%|██████████| 2964/2964 [00:00<00:00, 71612.84 examples/s]


Dataset({
    features: ['html_url', 'title', 'comments', 'body', 'comment_length'],
    num_rows: 2175
})

In [39]:
comments_dataset[:5]

{'html_url': ['https://github.com/huggingface/datasets/issues/2945',
  'https://github.com/huggingface/datasets/issues/2943',
  'https://github.com/huggingface/datasets/issues/2943',
  'https://github.com/huggingface/datasets/issues/2943',
  'https://github.com/huggingface/datasets/issues/2943'],
 'title': ['Protect master branch',
  'Backwards compatibility broken for cached datasets that use `.filter()`',
  'Backwards compatibility broken for cached datasets that use `.filter()`',
  'Backwards compatibility broken for cached datasets that use `.filter()`',
  'Backwards compatibility broken for cached datasets that use `.filter()`'],
 'comments': ['@lhoestq now the 2 are implemented.\r\n\r\nPlease note that for the the second protection, finally I have chosen to protect the master branch only from **merge commits** (see update comment above), so no need to disable/re-enable the protection on each release (direct commits, different from merge commits, can be pushed to the remote master

In [40]:
def concatenate_text(example):
    return {'text':example['title'] + ' \n' + example['body'] + ' \n' + example['comments']}

In [41]:
comments_dataset = comments_dataset.map(concatenate_text)

Map: 100%|██████████| 2175/2175 [00:00<00:00, 8001.65 examples/s]


In [43]:
comments_dataset[0]

{'html_url': 'https://github.com/huggingface/datasets/issues/2945',
 'title': 'Protect master branch',
 'comments': '@lhoestq now the 2 are implemented.\r\n\r\nPlease note that for the the second protection, finally I have chosen to protect the master branch only from **merge commits** (see update comment above), so no need to disable/re-enable the protection on each release (direct commits, different from merge commits, can be pushed to the remote master branch; and eventually reverted without messing up the repo history).',
 'body': 'After accidental merge commit (91c55355b634d0dc73350a7ddee1a6776dbbdd69) into `datasets` master branch, all commits present in the feature branch were permanently added to `datasets` master branch history, as e.g.:\r\n- 00cc036fea7c7745cfe722360036ed306796a3f2\r\n- 13ae8c98602bbad8197de3b9b425f4c78f582af1\r\n- ...\r\n\r\nI propose to protect our master branch, so that we avoid we can accidentally make this kind of mistakes in the future:\r\n- [x] For Pul

In [44]:
from transformers import AutoTokenizer, AutoModel
model_ckpt = 'sentence-transformers/bert-base-nli-mean-tokens'
tokenizer = AutoTokenizer.from_pretrained(model_ckpt)
model = AutoModel.from_pretrained(model_ckpt)



In [46]:
import torch
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [45]:
def cls_pooling(model_output):
    return model_output.last_hidden_state[:, 0]

In [48]:
def get_embeddings(text_list):
    encoded_input = tokenizer(text_list, padding=True, truncation=True, return_tensors='pt')
    encoded_input = {k: v.to(device) for k, v in encoded_input.items()}
    model_output = model(**encoded_input)
    return cls_pooling(model_output)

In [50]:
comments_dataset

Dataset({
    features: ['html_url', 'title', 'comments', 'body', 'comment_length', 'text'],
    num_rows: 2175
})

In [51]:
embedding = get_embeddings(comments_dataset['text'][0])
embedding.shape

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


torch.Size([1, 768])

In [61]:
embedding.detach().cpu().numpy()[0].shape

(768,)

In [54]:
embeddings_dataset = comments_dataset.map(lambda x: {'embeddings': get_embeddings(x['text']).detach().cpu().numpy()[0]})

Map:   0%|          | 1/2175 [00:00<24:53,  1.46 examples/s]


RuntimeError: The size of tensor a (1672) must match the size of tensor b (512) at non-singleton dimension 1