In [1]:
from IPython import get_ipython
from IPython.display import clear_output, display

ipython = get_ipython()
ipython.magic("load_ext autoreload")
ipython.magic("autoreload 2")

  ipython.magic("load_ext autoreload")
  ipython.magic("autoreload 2")


In [2]:
import os
from typing import List, Optional, Union, Dict, Tuple
from pathlib import Path 

import torch
from torch import Tensor
import numpy as np
import einops
from fancy_einsum import einsum
import circuitsvis as cv

import transformer_lens.utils as tl_utils

from transformer_lens import HookedTransformer
import transformer_lens.patching as patching

from transformers import AutoModelForCausalLM

from torch import Tensor
from jaxtyping import Float
import plotly.express as px

from functools import partial

from torchtyping import TensorType as TT

from path_patching_cm.path_patching import Node, IterNode, path_patch, act_patch
from path_patching_cm.ioi_dataset import IOIDataset, NAMES
from neel_plotly import imshow as imshow_n

from utils.visualization import imshow_p, plot_attention_heads, plot_attention
from utils.data_utils import generate_data_and_caches, UniversalDataset
from utils.metrics import compute_logit_diff, compute_probability_diff, compute_probability_mass, compute_rank_0_rate
from utils.visualization_utils import (
    plot_attention_heads,
    scatter_attention_and_contribution,
    get_attn_head_patterns
)

if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [3]:
torch.set_grad_enabled(False)

model = HookedTransformer.from_pretrained(
    "EleutherAI/pythia-160m",
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    refactor_factored_attn_matrices=False,
)
model.set_use_hook_mlp_in(True)

config.json:   0%|          | 0.00/569 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/375M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/396 [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/2.11M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/99.0 [00:00<?, ?B/s]

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Loaded pretrained model EleutherAI/pythia-160m into HookedTransformer


In [28]:
def get_positional_logits(
        logits: Float[Tensor, "batch seq d_vocab"],
        positions: Float[Tensor, "batch"] = None
)-> Float[Tensor, "batch d_vocab"]:
    """Gets the logits at the provided positions. If no positions are provided, the final logits are returned.

    Args:
        logits (torch.Tensor): Logits to use.
        positions (torch.Tensor): Positions to get logits at. This should be a tensor of shape (batch_size,).

    Returns:
        torch.Tensor: Logits at the provided positions.
    """
    if positions is None:
        return logits[:, -1, :]
    
    return logits[range(logits.size(0)), positions, :]


def compute_logit_diff(
        logits: Float[Tensor, "batch seq d_vocab"], 
        answer_token_indices: Float[Tensor, "batch num_answers"],
        positions: Float[Tensor, "batch"] = None,
        flags_tensor: torch.Tensor = None,
        per_prompt=False,
        mode="simple"
)-> Float[Tensor, "batch num_answers"]:
    """Computes the difference between a correct and incorrect logit (or mean of a group of logits) for each item in the batch.

    Takes the full logits, and the indices of the tokens to compare. These indices can be of multiple types as follows:

    - Simple: The tensor should be of shape (batch_size, 2), where the first index in the third dimension is the correct token index,
        and the second index is the incorrect token index.

    - Pairs: In this mode, answer_token_indices is a 3D tensor of shape (batch, num_pairs, 2). For each pair, you'll need to compute 
             the difference between the logits at the two indices, then average these differences across each pair for every batch item.

    - Groups: Here, answer_token_indices is also a 3D tensor of shape (batch, num_tokens, 2). The third dimension indicates group membership 
              (correct or incorrect). The mean logits for each group are calculated and then subtracted from each other.
              

    Args:
        logits (torch.Tensor): Logits to use.
        answer_token_indices (torch.Tensor): Indices of the tokens to compare.
        positions (torch.Tensor): Positions to get logits at. Should be one position per batch item.

    Returns:
        torch.Tensor: Difference between the logits of the provided tokens.
    """
    logits = get_positional_logits(logits, positions)
    
    # Mode 1: Simple
    if mode == "simple":
        correct_logits = logits[torch.arange(logits.size(0)), answer_token_indices[:, 0]]
        incorrect_logits = logits[torch.arange(logits.size(0)), answer_token_indices[:, 1]]
        logit_diff = correct_logits - incorrect_logits

    # Mode 2: Pairs
    elif mode == "pairs":
        pair_diffs = logits[torch.arange(logits.size(0))[:, None], answer_token_indices[..., 0]] - \
                     logits[torch.arange(logits.size(0))[:, None], answer_token_indices[..., 1]]
        logit_diff = pair_diffs.mean(dim=1)

    # Mode 3: Groups
    elif mode == "groups":
        assert flags_tensor is not None
        logit_diff = torch.zeros(logits.size(0), device=logits.device)

        for i in range(logits.size(0)):
            selected_logits = logits[i, answer_token_indices[i]]

            # Calculate the logit difference using the correct/incorrect flags
            correct_logits = selected_logits[flags_tensor[i] == 1]
            incorrect_logits = selected_logits[flags_tensor[i] == -1]

            # Handle cases where there are no correct or incorrect logits
            if len(correct_logits) > 0:
                correct_mean = correct_logits.mean()
            else:
                correct_mean = 0

            if len(incorrect_logits) > 0:
                incorrect_mean = incorrect_logits.mean()
            else:
                incorrect_mean = 0

            logit_diff[i] = correct_mean - incorrect_mean

    else:
        raise ValueError("Invalid mode specified")

    return logit_diff.mean() if not per_prompt else logit_diff



import torch
import torch.nn.functional as F

def compute_probability_diff(
        logits: torch.Tensor, 
        answer_token_indices: torch.Tensor,
        positions: torch.Tensor = None,
        flags_tensor: torch.Tensor = None,
        per_prompt=False,
        mode="simple"
) -> torch.Tensor:
    """Computes the difference between probability of a correct and incorrect logit (or mean of a group of logits) for each item in the batch.

    Takes the full logits, and the indices of the tokens to compare. These indices can be of multiple types as follows:

    - Simple: The tensor should be of shape (batch_size, 2), where the first index in the third dimension is the correct token index,
        and the second index is the incorrect token index.

    - Pairs: In this mode, answer_token_indices is a 3D tensor of shape (batch, num_pairs, 2). For each pair, you'll need to compute 
             the difference between the probabilities at the two indices, then average these differences across each pair for every batch item.

    - Groups: Here, answer_token_indices is also a 3D tensor of shape (batch, num_tokens, 2). The third dimension indicates group membership 
              (correct or incorrect). The mean probabilities for each group are calculated and then subtracted from each other.
              

    Args:
        logits (torch.Tensor): Logits to use.
        answer_token_indices (torch.Tensor): Indices of the tokens to compare.
        positions (torch.Tensor): Positions to get logits at. Should be one position per batch item.

    Returns:
        torch.Tensor: Difference between the logits of the provided tokens.
    """
    logits = get_positional_logits(logits, positions)
    probabilities = torch.softmax(logits, dim=-1)  # Applying softmax to logits
    print(f"probabilities={probabilities.shape}")

    # Mode 1: Simple
    if mode == "simple":
        correct_probs = probabilities[torch.arange(logits.size(0)), answer_token_indices[:, 0]]
        incorrect_probs = probabilities[torch.arange(logits.size(0)), answer_token_indices[:, 1]]
        prob_diff = correct_probs - incorrect_probs

    # Mode 2: Pairs
    elif mode == "pairs":
        pair_diffs = probabilities[torch.arange(logits.size(0))[:, None], answer_token_indices[..., 0]] - \
                     probabilities[torch.arange(logits.size(0))[:, None], answer_token_indices[..., 1]]
        prob_diff = pair_diffs.mean(dim=1)

    # Mode 3: Groups
    elif mode == "groups":
        # Initialize tensors to store the probability differences for each batch item
        assert flags_tensor is not None
        prob_diff = torch.zeros(logits.size(0), device=logits.device)

        for i in range(logits.size(0)):
            # Select the probabilities for the token IDs of this batch item
            selected_probs = probabilities[i, answer_token_indices[i]]

            # Calculate the probability difference using the correct/incorrect flags
            correct_probs = selected_probs[flags_tensor[i] == 1]
            incorrect_probs = selected_probs[flags_tensor[i] == -1]

            # Handle cases where there are no correct or incorrect tokens
            if len(correct_probs) > 0:
                correct_mean = correct_probs.mean()
            else:
                correct_mean = 0

            if len(incorrect_probs) > 0:
                incorrect_mean = incorrect_probs.mean()
            else:
                incorrect_mean = 0

            prob_diff[i] = correct_mean - incorrect_mean

    # Mode 4: Group Sum
    elif mode == "group_sum":
        assert flags_tensor is not None
        prob_diff = torch.zeros(logits.size(0), device=logits.device)

        for i in range(logits.size(0)):
            selected_probs = probabilities[i, answer_token_indices[i]]

            # Calculate the sum of probabilities using the correct/incorrect flags
            correct_sum = selected_probs[flags_tensor[i] == 1].sum()
            incorrect_sum = selected_probs[flags_tensor[i] == -1].sum()

            prob_diff[i] = incorrect_sum - correct_sum

    else:
        raise ValueError("Invalid mode specified")

    return prob_diff.mean() if not per_prompt else prob_diff


def compute_probability_mass(
        logits: torch.Tensor, 
        answer_token_indices: torch.Tensor,
        positions: torch.Tensor = None,
        flags_tensor: torch.Tensor = None,
        group="correct",
        mode="simple"
) -> torch.Tensor:
    logits = get_positional_logits(logits, positions)
    probabilities = torch.softmax(logits, dim=-1)

    # Determine the flag value based on the specified group
    flag_value = 1 if group == "correct" else -1

    # Mode logic
    if mode == "simple":
        selected_indices = answer_token_indices[:, 0] if group == "correct" else answer_token_indices[:, 1]
        group_probs = probabilities[torch.arange(logits.size(0)), selected_indices]

    elif mode == "pairs":
        group_probs = torch.zeros(logits.size(0), device=logits.device)
        for i in range(logits.size(0)):
            for pair in answer_token_indices[i]:
                selected_index = pair[0] if group == "correct" else pair[1]
                group_probs[i] += probabilities[i, selected_index]
            group_probs[i] /= answer_token_indices.size(1)

    elif mode == "groups":
        assert flags_tensor is not None
        group_probs = torch.zeros(logits.size(0), device=logits.device)

        for i in range(logits.size(0)):
            selected_probs = probabilities[i, answer_token_indices[i]]
            group_probs[i] = selected_probs[flags_tensor[i] == flag_value].mean()

    elif mode == "group_sum":
        assert flags_tensor is not None
        group_probs = torch.zeros(logits.size(0), device=logits.device)

        for i in range(logits.size(0)):
            selected_probs = probabilities[i, answer_token_indices[i]]
            group_probs[i] = selected_probs[flags_tensor[i] == flag_value].sum()

    else:
        raise ValueError("Invalid mode specified")

    return group_probs.mean()



def compute_rank_0_rate(
        logits: torch.Tensor, 
        answer_token_indices: torch.Tensor,
        positions: torch.Tensor = None,
        flags_tensor: torch.Tensor = None,
        group="correct",
        mode="simple"
) -> torch.Tensor:
    logits = get_positional_logits(logits, positions)
    probabilities = torch.softmax(logits, dim=-1)

    # Mode logic
    if mode == "simple":
        top_rank_indices = probabilities.argmax(dim=-1)
        correct_indices = answer_token_indices[:, 0] if group == "correct" else answer_token_indices[:, 1]
        rank_0_rate = (top_rank_indices == correct_indices).float().mean()

    elif mode == "pairs":
        rank_0_rate = torch.zeros(logits.size(0), device=logits.device)
        for i in range(logits.size(0)):
            for pair in answer_token_indices[i]:
                top_rank_index = probabilities[i].argmax()
                correct_index = pair[0] if group == "correct" else pair[1]
                rank_0_rate[i] += (top_rank_index == correct_index).float()
            rank_0_rate[i] /= answer_token_indices.size(1)

    elif mode == "groups":
        assert flags_tensor is not None
        rank_0_rate = torch.zeros(logits.size(0), device=logits.device)

        for i in range(logits.size(0)):
            selected_probs = probabilities[i, answer_token_indices[i]]
            top_rank_id = selected_probs.argmax()
            rank_0_rate[i] = (flags_tensor[i, top_rank_id] == 1).float() if group == "correct" else \
                             (flags_tensor[i, top_rank_id] == -1).float()

    else:
        raise ValueError("Invalid mode specified")

    return rank_0_rate.mean()



## IOI

### Old

In [81]:
def _logits_to_mean_logit_diff(logits: Float[Tensor, "batch seq d_vocab"], ioi_dataset: IOIDataset, per_prompt=False):
    '''
    Returns logit difference between the correct and incorrect answer.

    If per_prompt=True, return the array of differences rather than the average.
    '''

    # Only the final logits are relevant for the answer
    # Get the logits corresponding to the indirect object / subject tokens respectively
    io_logits: Float[Tensor, "batch"] = logits[range(logits.size(0)), ioi_dataset.word_idx["end"], ioi_dataset.io_tokenIDs]
    print(io_logits.shape)
    s_logits: Float[Tensor, "batch"] = logits[range(logits.size(0)), ioi_dataset.word_idx["end"], ioi_dataset.s_tokenIDs]
    # Find logit difference
    answer_logit_diff = io_logits - s_logits
    return answer_logit_diff if per_prompt else answer_logit_diff.mean()

In [82]:
N = 70
ioi_dataset, abc_dataset, _, _, _ = generate_data_and_caches(model, N, verbose=True)
clean_toks = ioi_dataset.toks.to(device)
corrupted_toks = abc_dataset.toks.to(device)

Average logit diff (IOI dataset): 4.1336
Average logit diff (ABC dataset): -4.0758


In [83]:
logits = model(clean_toks)
_logits_to_mean_logit_diff(logits, ioi_dataset, per_prompt=True)

torch.Size([70])


tensor([ 3.3623,  4.5626,  4.4264,  5.5464,  4.9847,  3.3949,  6.6630,  4.6075,
         4.1188,  0.7323,  3.5184,  5.4281, -1.2406,  3.1855,  3.9544,  4.9033,
         4.9604, -0.5812,  6.0159,  4.4829,  4.9364,  3.5120,  3.6026,  5.8260,
         1.5506,  3.6591,  5.1734,  7.8747,  1.9717,  3.0304,  1.3404,  5.6763,
         6.4933,  3.3304,  5.4253,  3.1553,  1.6709,  5.3956,  3.8436,  1.4698,
         7.1663,  7.1676,  6.1948,  5.9164,  6.3674,  5.1585,  6.8696,  3.2112,
         0.9600,  5.3084,  2.2612,  4.1867,  2.8197,  6.8392,  6.7878,  4.9921,
         4.7495,  4.8143,  3.2216,  5.5183,  0.8875,  4.1779,  3.3880,  3.5489,
         0.9434,  4.0729,  3.7018,  2.4539,  7.2850,  2.4210], device='cuda:0')

### New

In [84]:
ds = UniversalDataset.from_ioi(model, 70)

Average logit diff (IOI dataset): 4.1336
Average logit diff (ABC dataset): -4.0758


In [85]:
logits = model(ds.toks)

In [86]:
compute_logit_diff(logits, ds.answer_toks, ds.positions, per_prompt=True)

tensor([ 3.3623,  4.5626,  4.4264,  5.5464,  4.9847,  3.3949,  6.6630,  4.6075,
         4.1188,  0.7323,  3.5184,  5.4281, -1.2406,  3.1855,  3.9544,  4.9033,
         4.9604, -0.5812,  6.0159,  4.4829,  4.9364,  3.5120,  3.6026,  5.8260,
         1.5506,  3.6591,  5.1734,  7.8747,  1.9717,  3.0304,  1.3404,  5.6763,
         6.4933,  3.3304,  5.4253,  3.1553,  1.6709,  5.3956,  3.8436,  1.4698,
         7.1663,  7.1676,  6.1948,  5.9164,  6.3674,  5.1585,  6.8696,  3.2112,
         0.9600,  5.3084,  2.2612,  4.1867,  2.8197,  6.8392,  6.7878,  4.9921,
         4.7495,  4.8143,  3.2216,  5.5183,  0.8875,  4.1779,  3.3880,  3.5489,
         0.9434,  4.0729,  3.7018,  2.4539,  7.2850,  2.4210], device='cuda:0')

In [89]:
compute_probability_diff(logits, ds.answer_toks, ds.positions, per_prompt=True)

probabilities=torch.Size([70, 50304])


tensor([ 0.3522,  0.1489,  0.1367,  0.2285,  0.1738,  0.2953,  0.3568,  0.4153,
         0.2671,  0.0405,  0.6666,  0.4521, -0.0517,  0.2696,  0.2368,  0.2904,
         0.0940, -0.0523,  0.4264,  0.4820,  0.2183,  0.2214,  0.2022,  0.1831,
         0.1875,  0.0674,  0.4456,  0.2823,  0.1970,  0.0493,  0.3325,  0.2131,
         0.7092,  0.1755,  0.3654,  0.1416,  0.0983,  0.6070,  0.1499,  0.1279,
         0.7195,  0.4302,  0.1586,  0.2769,  0.2814,  0.1193,  0.5471,  0.1647,
         0.0067,  0.2246,  0.0681,  0.2199,  0.6249,  0.3559,  0.3219,  0.2941,
         0.2657,  0.1763,  0.3662,  0.1637,  0.1565,  0.1751,  0.1243,  0.4158,
         0.0082,  0.3939,  0.2468,  0.2145,  0.3334,  0.2217], device='cuda:0')

## Greater-Than

### Old

In [13]:
from data.greater_than_dataset import get_prob_diff, YearDataset, get_valid_years

In [39]:
ds_old = YearDataset(get_valid_years(model.tokenizer, 1100, 1800), 1000, Path("data/potential_nouns.txt"), model.tokenizer)

# def batch(iterable, n:int=1):
#    current_batch = []
#    for item in iterable:
#        current_batch.append(item)
#        if len(current_batch) == n:
#            yield current_batch
#            current_batch = []
#    if current_batch:
#        yield current_batch

# clean = list(batch(ds.good_sentences, 9))
# labels = list(batch(ds.years_YY, 9))
# corrupted = list(batch(ds.bad_sentences, 9))

In [40]:
IDX = 768
#model.to_str_tokens(ds.good_toks[IDX]), model.to_str_tokens(ds.bad_toks[IDX])

In [41]:
import torch

def prepare_indices_for_prob_diff(tokenizer, years):
    """
    Prepares two tensors for use with the compute_probability_diff function in 'groups' mode.

    Args:
        tokenizer (PreTrainedTokenizer): Tokenizer to convert years to token indices.
        years (torch.Tensor): Tensor containing the year for each prompt in the batch.

    Returns:
        torch.Tensor, torch.Tensor: Two tensors, one for token IDs and one for correct/incorrect flags.
    """

    # Get the indices for years 00 to 99
    year_indices = get_year_indices(tokenizer)  # Tensor of size 100 with token IDs for years

    # Prepare tensors to store token IDs and correct/incorrect flags
    token_ids_tensor = year_indices.repeat(years.size(0), 1)  # Repeat the year_indices for each batch item
    flags_tensor = torch.zeros_like(token_ids_tensor)  # Initialize the flags tensor with zeros

    for i, year in enumerate(years):
        # Mark years greater than the given year as correct (1)
        flags_tensor[i, year + 1:] = 1
        # Mark years less than or equal to the given year as incorrect (-1)
        flags_tensor[i, :year + 1] = -1

    return token_ids_tensor, flags_tensor



In [42]:
#input_length = 1 + len(model.tokenizer(ds.good_sentences[0])[0])
prob_diff = get_prob_diff(model.tokenizer)

In [43]:
from utils.circuit_utils import run_with_batches

clean_logits = run_with_batches(model, ds_old.good_toks.to(device), batch_size=20, max_seq_len=12)
corrupted_logits = run_with_batches(model, ds_old.bad_toks.to(device), batch_size=20, max_seq_len=12)

In [44]:
prob_diff(clean_logits,ds_old.years_YY)

tensor(-0.8363, device='cuda:0')

### New

In [29]:
ds = UniversalDataset.from_greater_than(model, 1000)

In [30]:
logits = model(ds.toks)

In [32]:
compute_probability_diff(clean_logits, ds.answer_toks, flags_tensor=ds.group_flags, mode="group_sum")

probabilities=torch.Size([1000, 50304])


tensor(-0.8294, device='cuda:0')

In [None]:
logit_diff = compute_logit_diff(logits=clean_logits, answer_token_indices=answer_tokens, flags_tensor=group_flags, mode="groups")
probability_diff = compute_probability_diff(logits=clean_logits, answer_token_indices=answer_tokens, flags_tensor=group_flags, mode="group_sum")
probability_mass = compute_probability_mass(logits=clean_logits, answer_token_indices=answer_tokens, flags_tensor=group_flags, mode="groups", group="correct")
rank_0_rate = compute_rank_0_rate(logits=clean_logits, answer_token_indices=answer_tokens, flags_tensor=group_flags, mode="groups", group="correct")

## Sentiment

In [45]:
from data.sentiment_datasets import get_dataset, PromptType, get_prompts
from utils.circuit_analysis import get_logit_diff as get_logit_diff_ca

### Classification

#### Old

In [46]:
ds_type = PromptType.CLASSIFICATION_4

In [47]:
ds = get_dataset(model, device, prompt_type=ds_type)
ds.all_prompts[0]

Reading prompts from config and filtering


'I thought this movie was amazing, I loved it. The acting was awesome, the plot was beautiful, and overall the movie was just very good. Review Sentiment:'

In [48]:
ds.clean_tokens.shape

torch.Size([22, 35])

In [49]:
ds.answer_tokens.shape

torch.Size([22, 1, 2])

In [50]:
clean_logits = model(ds.clean_tokens.to(device))
corrupted_logits = model(ds.corrupted_tokens.to(device))

In [51]:
ds.answer_tokens.shape

torch.Size([22, 1, 2])

In [52]:
from utils.metrics import CircuitMetric
logit_diff_metric = CircuitMetric("logit_diff_multi", partial(get_logit_diff_ca, answer_tokens=ds.answer_tokens))

In [53]:
logit_diff_metric(clean_logits), logit_diff_metric(corrupted_logits)

(tensor(0.3842, device='cuda:0'), tensor(-0.3842, device='cuda:0'))

#### New

In [61]:
ds = UniversalDataset.from_sentiment(model, "class")

Reading prompts from config and filtering


In [63]:
logits = model(ds.toks)
flipped_logits = model(ds.flipped_toks)

In [64]:
compute_logit_diff(logits, ds.answer_toks, mode="pairs"), compute_logit_diff(flipped_logits, ds.answer_toks, mode="pairs")

(tensor(0.3842, device='cuda:0'), tensor(-0.3842, device='cuda:0'))

### Continuation

#### Old

In [65]:
ds_type = PromptType.COMPLETION_2

In [66]:
ds = get_dataset(model, device, prompt_type=ds_type)
ds.all_prompts[0]

Reading prompts from config and filtering


'I thought this movie was amazing, I loved it. The acting was awesome, the plot was beautiful, and overall it was just very'

In [67]:
ds.clean_tokens.shape

torch.Size([22, 28])

In [68]:
ds.answer_tokens.shape

torch.Size([22, 5, 2])

In [69]:
clean_logits = model(ds.clean_tokens.to(device))
corrupted_logits = model(ds.corrupted_tokens.to(device))

In [70]:
ds.answer_tokens.shape

torch.Size([22, 5, 2])

In [71]:
from utils.metrics import CircuitMetric
logit_diff_metric = CircuitMetric("logit_diff_multi", partial(get_logit_diff_ca, answer_tokens=ds.answer_tokens))

In [72]:
logit_diff_metric(clean_logits), logit_diff_metric(corrupted_logits)

(tensor(3.2004, device='cuda:0'), tensor(-3.2004, device='cuda:0'))

#### New

In [73]:
ds = UniversalDataset.from_sentiment(model, "cont")

Reading prompts from config and filtering


In [74]:
logits = model(ds.toks)
flipped_logits = model(ds.flipped_toks)

In [75]:
compute_logit_diff(logits, ds.answer_toks, mode="pairs"), compute_logit_diff(flipped_logits, ds.answer_toks, mode="pairs")

(tensor(3.2004, device='cuda:0'), tensor(-3.2004, device='cuda:0'))