In [1]:
from tokenizers import Tokenizer
from bastionlab.polars.policy import Policy, Aggregation, Log
from bastionlab.polars import train_test_split
import polars as pl
from bastionlab import Connection
from bastionlab.tokenizers import RemoteTokenizer
from bastionlab.torch.remote_torch import RemoteDataset
from bastionlab.torch.learner import RemoteLearner


# SMS Spam Collection: a header-less, tab-separated file whose first column is
# the class label and whose second column is the raw message text.
file_path = "./SMSSpamCollection"

# Read the TSV file with Polars, naming the columns `label` and `text`
raw_df = pl.read_csv(file_path, has_header=False, sep="\t", new_columns=["label", "text"])

# The encoding below maps every non-"spam" label to 0, so a mis-parsed row or an
# unexpected class would silently become "ham". Fail loudly instead.
unexpected_labels = set(raw_df["label"].unique().to_list()) - {"ham", "spam"}
assert not unexpected_labels, f"unexpected labels in {file_path}: {unexpected_labels}"

# Encode the target as an integer: `spam` -> 1, `ham` -> 0
df = raw_df.with_column(
    pl.when(pl.col("label") == "spam").then(1).otherwise(0).alias("label")
)

# View the first few elements of the DataFrame
df.head()

  from .autonotebook import tqdm as notebook_tqdm


label,text
i64,str
0,"""Go until juron..."
0,"""Ok lar... Joki..."
1,"""Free entry in ..."
0,"""U dun say so e..."
0,"""Nah I don't th..."


In [2]:
# Padding and truncation share one length so every encoded sequence should come
# out exactly MAX_SEQ_LEN tokens long; keeping two separate literals lets them drift.
MAX_SEQ_LEN = 32
# Number of rows sent to the server (a small subset keeps the demo fast).
N_SAMPLES = 64
# Fraction of the sent rows held out for the test split.
TEST_SIZE = 0.2

tokenizer = RemoteTokenizer.from_hugging_face_pretrained("distilbert-base-uncased")
tokenizer.enable_padding(length=MAX_SEQ_LEN)
tokenizer.enable_truncation(max_length=MAX_SEQ_LEN)

# Needs a BastionLab server listening on localhost; if none is running this
# raises ConnectionRefusedError.
connection = Connection("localhost")

# Data policy attached to the upload: queries aggregating at least 10 rows are in
# the safe zone; other queries go to Log() (presumably logged rather than
# rejected -- confirm against the BastionLab policy docs).
policy = Policy(safe_zone=Aggregation(min_agg_size=10), unsafe_handling=Log())

# Upload the first N_SAMPLES rows under the policy above
rdf = connection.client.polars.send_df(df.limit(N_SAMPLES), policy=policy)

# Split the remote dataframe into train and test sets
train_rdf, test_rdf = train_test_split(rdf, test_size=TEST_SIZE)

# Message texts of each split, as RemoteSeries objects
train_inputs = train_rdf.column("text")
test_inputs = test_rdf.column("text")

# Labels of each split, converted with `.to_tensor()`
train_label = train_rdf.column("label").to_tensor()
test_label = test_rdf.column("label").to_tensor()

# Tokenize the `text` fields into token ids and attention masks
train_ids, train_mask = tokenizer.encode(train_inputs)
test_ids, test_mask = tokenizer.encode(test_inputs)

# Build the train and test RemoteDatasets from (ids, mask) inputs and labels
train_rds = RemoteDataset(inputs=[train_ids, train_mask], label=train_label)
test_rds = RemoteDataset(inputs=[test_ids, test_mask], label=test_label)

ConnectionRefusedError: [Errno 111] Connection refused

In [17]:
from transformers import DistilBertForSequenceClassification, DistilBertTokenizer
from bastionlab.torch.utils import MultipleOutputWrapper

# Pretrained DistilBERT with a 2-class head (labels are 0/1 from the data cell),
# wrapped so only one of its outputs is returned. `torchscript=True` presumably
# makes the Hugging Face model return plain tuples, which the wrapper indexes.
model = MultipleOutputWrapper(
    DistilBertForSequenceClassification.from_pretrained(
        "distilbert-base-uncased",
        num_labels=2,
        output_hidden_states=False,
        output_attentions=False,
        torchscript=True,
    ),
    0,  # output index to keep -- presumably the classification logits; confirm
)

Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForSequenceClassification: ['vocab_projector.weight', 'vocab_layer_norm.weight', 'vocab_transform.bias', 'vocab_layer_norm.bias', 'vocab_transform.weight', 'vocab_projector.bias']
- This IS expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.weight', 'pre_classifier.weight', 'pre_clas

In [18]:
from bastionlab.torch.optimizer_config import Adam

# Training hyperparameters
MAX_BATCH_SIZE = 2
LEARNING_RATE = 5e-5
NB_EPOCHS = 1

# Configure remote training of `model` on the training RemoteDataset with
# cross-entropy loss and Adam; the model is registered as "DistilBERT".
remote_learner = connection.client.torch.RemoteLearner(
    model,
    train_rds,
    max_batch_size=MAX_BATCH_SIZE,
    loss="cross_entropy",
    optimizer=Adam(lr=LEARNING_RATE),
    model_name="DistilBERT",
)

remote_learner.fit(nb_epochs=NB_EPOCHS)

Sending DistilBERT: 100%|████████████████████| 268M/268M [00:07<00:00, 36.5MB/s]


identifier: "8e43e41f447870bc47395f069147ae31ee6c710e471cc71fd45967ca55dd733a"
name: "DistilBERT"

{"inputs": [{"identifier": "4b703fa2-a9cf-41e9-aff3-417e6c239a9b"},{"identifier": "86e3fe50-adeb-4da2-b97b-674d85c6b474"}], "label": {"identifier": "56ee0aef-165b-446b-94a5-3e19f3246751"}, "nb_samples": 51, "privacy_limit": -1.0}




GRPCException: Connection to the gRPC server failed: code=StatusCode.UNAVAILABLE message=Connection reset by peer