You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
I am trying to run the code below on a GPU. Where should I specify the device, and what is the syntax?
Should it be `device='gpu'` or `device='cuda'`, and where should I pass it?
This is your old code snippet:
from transformers import T5ForConditionalGeneration, T5Tokenizer
# Default Hugging Face model id and tokenizer id that ``generate`` loads when
# the caller does not supply ``model_name`` / ``tokenizer_name``.
MODEL = "kiri-ai/t5-base-qa-summary-emotion"
TOKENIZER = "t5-base"

# Lazily populated caches for the loaded model and tokenizer. ``generate``
# declares these ``global`` and compares them against None, so they must
# exist at module level before its first call (otherwise: NameError).
model = None
tokenizer = None
def generate(input_text, model_name: str = None, tokenizer_name: str = None,
             device=None):
    """Run T5 text generation on one prompt or a list of prompts.

    The model and tokenizer are loaded on first use and cached in the
    module-level globals ``model`` and ``tokenizer``; later calls reuse them.

    Args:
        input_text: A single prompt string, or a list of prompt strings.
        model_name: Model id/path for ``T5ForConditionalGeneration.
            from_pretrained`` on first load; ``MODEL`` is used when None.
            Ignored once a model is already cached.
        tokenizer_name: Tokenizer id/path for ``T5Tokenizer.from_pretrained``
            on first load; ``TOKENIZER`` is used when None. Ignored once a
            tokenizer is already cached.
        device: Optional torch device, e.g. ``"cuda"``, ``"cuda:0"`` or
            ``"cpu"``. When given, the cached model is moved there (and stays
            there for later calls). The tokenized inputs are always placed on
            the device the model currently lives on.

    Returns:
        A list of decoded strings when ``input_text`` is a list, otherwise a
        single decoded string.
    """
    global model
    global tokenizer

    # Read the globals through globals().get so the first call works even if
    # the module never pre-declared ``model = None`` / ``tokenizer = None``.
    if globals().get("model") is None:
        model = T5ForConditionalGeneration.from_pretrained(
            MODEL if model_name is None else model_name)
    if globals().get("tokenizer") is None:
        tokenizer = T5Tokenizer.from_pretrained(
            TOKENIZER if tokenizer_name is None else tokenizer_name)

    if device is not None:
        # nn.Module.to moves the parameters and returns the module itself.
        model = model.to(device)

    is_list = isinstance(input_text, list)
    features = tokenizer(input_text, padding=True, return_tensors='pt')
    # Inputs must live on the same device as the model; otherwise generate()
    # fails as soon as the model has been moved to a GPU.
    features = features.to(model.device)
    tokens = model.generate(input_ids=features['input_ids'],
                            attention_mask=features['attention_mask'],
                            max_length=512)
    if is_list:
        return [tokenizer.decode(sequence, skip_special_tokens=True)
                for sequence in tokens]
    return tokenizer.decode(tokens[0], skip_special_tokens=True)
def process_item(item):
    """Return *item* formatted with the ``emotion: `` task prefix."""
    prefix = "emotion: "
    return f"{prefix}{item}"
def emotion(input_text, model_name: str = None, tokenizer_name: str = None):
    """Run the emotion task on one text or a list of texts.

    Every input is passed through ``process_item`` (which adds the
    ``emotion: `` prefix) and the resulting prompt(s) are sent to
    ``generate``. The return value is ``generate``'s: a list of strings for
    list input, a single string otherwise.
    """
    prompts = (list(map(process_item, input_text))
               if isinstance(input_text, list)
               else process_item(input_text))
    return generate(prompts, model_name=model_name,
                    tokenizer_name=tokenizer_name)
Best,
Chirag
The text was updated successfully, but these errors were encountered:
Hi
I am trying to run the code below on a GPU. Where should I specify the device, and what is the syntax?
Should it be `device='gpu'` or `device='cuda'`, and where should I pass it?
This is your old code snippet:
Best,
Chirag
The text was updated successfully, but these errors were encountered: